mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 21:17:08 +00:00
(improvement)(chat) Standardize the embedding service on langchain4j's TextSegment in place of EmbeddingQuery, staying compatible with the queryId parameter by storing it in the segment metadata. (#1202)
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
package com.tencent.supersonic.common.service;
|
||||
|
||||
import dev.langchain4j.store.embedding.EmbeddingQuery;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
@@ -13,9 +14,9 @@ public interface EmbeddingService {
|
||||
|
||||
void addCollection(String collectionName);
|
||||
|
||||
void addQuery(String collectionName, List<EmbeddingQuery> queries);
|
||||
void addQuery(String collectionName, List<TextSegment> queries);
|
||||
|
||||
void deleteQuery(String collectionName, List<EmbeddingQuery> queries);
|
||||
void deleteQuery(String collectionName, List<TextSegment> queries);
|
||||
|
||||
List<RetrieveQueryResult> retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, int num);
|
||||
|
||||
|
||||
@@ -2,9 +2,9 @@ package com.tencent.supersonic.common.service.impl;
|
||||
|
||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||
import dev.langchain4j.data.embedding.Embedding;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.model.embedding.EmbeddingModel;
|
||||
import dev.langchain4j.store.embedding.EmbeddingMatch;
|
||||
import dev.langchain4j.store.embedding.EmbeddingQuery;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
||||
import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
@@ -12,8 +12,13 @@ import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
import dev.langchain4j.store.embedding.TextSegmentConvert;
|
||||
import dev.langchain4j.store.embedding.filter.Filter;
|
||||
import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
|
||||
import org.apache.commons.collections.MapUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
@@ -21,9 +26,6 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.collections.MapUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service
|
||||
public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
@@ -39,17 +41,17 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addQuery(String collectionName, List<EmbeddingQuery> queries) {
|
||||
public void addQuery(String collectionName, List<TextSegment> queries) {
|
||||
EmbeddingStore embeddingStore = embeddingStoreFactory.create(collectionName);
|
||||
for (EmbeddingQuery query : queries) {
|
||||
String question = query.getQuery();
|
||||
for (TextSegment query : queries) {
|
||||
String question = query.text();
|
||||
Embedding embedding = embeddingModel.embed(question).content();
|
||||
embeddingStore.add(embedding, query);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void deleteQuery(String collectionName, List<EmbeddingQuery> queries) {
|
||||
public void deleteQuery(String collectionName, List<TextSegment> queries) {
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -66,21 +68,21 @@ public class EmbeddingServiceImpl implements EmbeddingService {
|
||||
.queryEmbedding(embeddedText).filter(filter).maxResults(num).build();
|
||||
|
||||
EmbeddingSearchResult result = embeddingStore.search(request);
|
||||
List<EmbeddingMatch<EmbeddingQuery>> relevant = result.matches();
|
||||
List<EmbeddingMatch<TextSegment>> relevant = result.matches();
|
||||
|
||||
RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult();
|
||||
retrieveQueryResult.setQuery(queryText);
|
||||
List<Retrieval> retrievals = new ArrayList<>();
|
||||
for (EmbeddingMatch<EmbeddingQuery> embeddingMatch : relevant) {
|
||||
for (EmbeddingMatch<TextSegment> embeddingMatch : relevant) {
|
||||
Retrieval retrieval = new Retrieval();
|
||||
EmbeddingQuery embedded = embeddingMatch.embedded();
|
||||
TextSegment embedded = embeddingMatch.embedded();
|
||||
retrieval.setDistance(1 - embeddingMatch.score());
|
||||
retrieval.setId(embedded.getQueryId());
|
||||
retrieval.setQuery(embedded.getQuery());
|
||||
retrieval.setId(TextSegmentConvert.getQueryId(embedded));
|
||||
retrieval.setQuery(embedded.text());
|
||||
Map<String, Object> metadata = new HashMap<>();
|
||||
if (Objects.nonNull(embedded)
|
||||
&& MapUtils.isNotEmpty(embedded.getMetadata())) {
|
||||
metadata.putAll(embedded.getMetadata());
|
||||
&& MapUtils.isNotEmpty(embedded.metadata().toMap())) {
|
||||
metadata.putAll(embedded.metadata().toMap());
|
||||
}
|
||||
retrieval.setMetadata(metadata);
|
||||
retrievals.add(retrieval);
|
||||
|
||||
@@ -1,21 +1,21 @@
|
||||
package dev.langchain4j.chroma.spring;
|
||||
|
||||
|
||||
import static dev.langchain4j.chroma.spring.Properties.PREFIX;
|
||||
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
||||
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
import static dev.langchain4j.chroma.spring.Properties.PREFIX;
|
||||
|
||||
@Configuration
|
||||
@EnableConfigurationProperties(Properties.class)
|
||||
public class ChromaAutoConfig {
|
||||
|
||||
@Bean
|
||||
@ConditionalOnProperty(PREFIX + ".embedding-store.base-url")
|
||||
EmbeddingStoreFactory milvusChatModel(Properties properties) {
|
||||
EmbeddingStoreFactory chromaChatModel(Properties properties) {
|
||||
return new ChromaEmbeddingStoreFactory(properties);
|
||||
}
|
||||
}
|
||||
@@ -17,7 +17,7 @@ public class ChromaEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
||||
EmbeddingStoreProperties embeddingStore = properties.getEmbeddingStore();
|
||||
return ChromaEmbeddingStore.builder()
|
||||
.baseUrl(embeddingStore.getBaseUrl())
|
||||
.collectionName(embeddingStore.getCollectionName())
|
||||
.collectionName(collectionName)
|
||||
.timeout(embeddingStore.getTimeout())
|
||||
.build();
|
||||
}
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
package dev.langchain4j.chroma.spring;
|
||||
|
||||
import java.time.Duration;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
import java.time.Duration;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
class EmbeddingStoreProperties {
|
||||
|
||||
private String baseUrl;
|
||||
private String collectionName;
|
||||
private Duration timeout;
|
||||
}
|
||||
@@ -1,18 +1,19 @@
|
||||
package dev.langchain4j.inmemory.spring;
|
||||
|
||||
import dev.langchain4j.store.embedding.EmbeddingQuery;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStore;
|
||||
import dev.langchain4j.store.embedding.EmbeddingStoreFactory;
|
||||
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
||||
|
||||
private static Map<String, InMemoryEmbeddingStore<EmbeddingQuery>> collectionNameToStore =
|
||||
private static Map<String, InMemoryEmbeddingStore<TextSegment>> collectionNameToStore =
|
||||
new ConcurrentHashMap<>();
|
||||
private Properties properties;
|
||||
|
||||
@@ -23,14 +24,12 @@ public class InMemoryEmbeddingStoreFactory implements EmbeddingStoreFactory {
|
||||
|
||||
@Override
|
||||
public synchronized EmbeddingStore create(String collectionName) {
|
||||
InMemoryEmbeddingStore<EmbeddingQuery> embeddingStore = collectionNameToStore.get(collectionName);
|
||||
InMemoryEmbeddingStore<TextSegment> embeddingStore = collectionNameToStore.get(collectionName);
|
||||
if (Objects.nonNull(embeddingStore)) {
|
||||
return embeddingStore;
|
||||
}
|
||||
if (Objects.isNull(embeddingStore)) {
|
||||
embeddingStore = new InMemoryEmbeddingStore();
|
||||
collectionNameToStore.putIfAbsent(collectionName, embeddingStore);
|
||||
}
|
||||
embeddingStore = new InMemoryEmbeddingStore();
|
||||
collectionNameToStore.putIfAbsent(collectionName, embeddingStore);
|
||||
return embeddingStore;
|
||||
|
||||
}
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
package dev.langchain4j.store.embedding;
|
||||
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.common.pojo.DataItem;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Data
|
||||
public class EmbeddingQuery {
|
||||
|
||||
private String queryId;
|
||||
|
||||
private String query;
|
||||
|
||||
private Map<String, Object> metadata;
|
||||
|
||||
private List<Double> queryEmbedding;
|
||||
public static List<EmbeddingQuery> convertToEmbedding(List<DataItem> dataItems) {
|
||||
return dataItems.stream().map(dataItem -> {
|
||||
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
||||
embeddingQuery.setQueryId(
|
||||
dataItem.getId() + dataItem.getType().name().toLowerCase());
|
||||
embeddingQuery.setQuery(dataItem.getName());
|
||||
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
|
||||
embeddingQuery.setMetadata(meta);
|
||||
embeddingQuery.setQueryEmbedding(null);
|
||||
return embeddingQuery;
|
||||
}).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package dev.langchain4j.store.embedding;
|
||||
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.common.pojo.DataItem;
|
||||
import dev.langchain4j.data.document.Metadata;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Data
|
||||
public class TextSegmentConvert {
|
||||
|
||||
public static final String QUERY_ID = "queryId";
|
||||
|
||||
public static List<TextSegment> convertToEmbedding(List<DataItem> dataItems) {
|
||||
return dataItems.stream().map(dataItem -> {
|
||||
Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class);
|
||||
TextSegment textSegment = TextSegment.from(dataItem.getName(), new Metadata(meta));
|
||||
addQueryId(textSegment, dataItem.getId() + dataItem.getType().name().toLowerCase());
|
||||
return textSegment;
|
||||
}).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
public static void addQueryId(TextSegment textSegment, String queryId) {
|
||||
textSegment.metadata().put(QUERY_ID, queryId);
|
||||
}
|
||||
|
||||
public static String getQueryId(TextSegment textSegment) {
|
||||
if (Objects.isNull(textSegment) || Objects.isNull(textSegment.metadata())) {
|
||||
return null;
|
||||
}
|
||||
return textSegment.metadata().get(QUERY_ID);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user