(improvement)(chat) The embedding model will be uniformly adopted using the textSegment and will be compatible with the queryId parameter. (#1202)

This commit is contained in:
lexluo09
2024-06-24 13:27:03 +08:00
committed by GitHub
parent a7d367baa3
commit 4b288d9815
13 changed files with 134 additions and 127 deletions

View File

@@ -14,16 +14,17 @@ import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.service.PluginService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.store.embedding.EmbeddingQuery;
import com.tencent.supersonic.common.service.EmbeddingService;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import dev.langchain4j.store.embedding.TextSegmentConvert;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.tuple.Pair;
@@ -116,16 +117,16 @@ public class PluginManager {
}
String presetCollection = embeddingConfig.getPresetCollection();
List<EmbeddingQuery> queries = new ArrayList<>();
List<TextSegment> queries = new ArrayList<>();
for (String id : queryIds) {
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
embeddingQuery.setQueryId(id);
queries.add(embeddingQuery);
TextSegment query = TextSegment.from("");
TextSegmentConvert.addQueryId(query, id);
queries.add(query);
}
embeddingService.deleteQuery(presetCollection, queries);
}
public void requestEmbeddingPluginAdd(List<EmbeddingQuery> queries) {
public void requestEmbeddingPluginAdd(List<TextSegment> queries) {
if (CollectionUtils.isEmpty(queries)) {
return;
}
@@ -133,10 +134,6 @@ public class PluginManager {
embeddingService.addQuery(presetCollection, queries);
}
public void requestEmbeddingPluginAddALL(List<Plugin> plugins) {
requestEmbeddingPluginAdd(convert(plugins));
}
public RetrieveQueryResult recognize(String embeddingText) {
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
@@ -158,15 +155,14 @@ public class PluginManager {
throw new RuntimeException("get embedding result failed");
}
public List<EmbeddingQuery> convert(List<Plugin> plugins) {
List<EmbeddingQuery> queries = Lists.newArrayList();
public List<TextSegment> convert(List<Plugin> plugins) {
List<TextSegment> queries = Lists.newArrayList();
for (Plugin plugin : plugins) {
List<String> exampleQuestions = plugin.getExampleQuestionList();
int num = 0;
for (String pattern : exampleQuestions) {
EmbeddingQuery query = new EmbeddingQuery();
query.setQueryId(generateUniqueEmbeddingId(num, plugin.getId()));
query.setQuery(pattern);
TextSegment query = TextSegment.from(pattern);
TextSegmentConvert.addQueryId(query, generateUniqueEmbeddingId(num, plugin.getId()));
queries.add(query);
num++;
}
@@ -176,8 +172,8 @@ public class PluginManager {
private Set<String> getEmbeddingId(List<Plugin> plugins) {
Set<String> embeddingIdSet = new HashSet<>();
for (EmbeddingQuery query : convert(plugins)) {
embeddingIdSet.add(query.getQueryId());
for (TextSegment query : convert(plugins)) {
TextSegmentConvert.addQueryId(query, TextSegmentConvert.getQueryId(query));
}
return embeddingIdSet;
}

View File

@@ -5,18 +5,12 @@ import com.tencent.supersonic.chat.api.pojo.request.SimilarQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.service.EmbeddingService;
import dev.langchain4j.store.embedding.EmbeddingQuery;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import java.net.URI;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import dev.langchain4j.store.embedding.TextSegmentConvert;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
@@ -32,6 +26,15 @@ import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
import java.net.URI;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@Slf4j
@Component
public class SimilarQueryManager {
@@ -52,15 +55,13 @@ public class SimilarQueryManager {
}
String queryText = similarQueryReq.getQueryText();
try {
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
embeddingQuery.setQueryId(String.valueOf(similarQueryReq.getQueryId()));
embeddingQuery.setQuery(queryText);
Map<String, Object> metaData = new HashMap<>();
metaData.put("agentId", similarQueryReq.getAgentId());
embeddingQuery.setMetadata(metaData);
TextSegment textSegment = TextSegment.from(queryText, new Metadata(metaData));
TextSegmentConvert.addQueryId(textSegment, String.valueOf(similarQueryReq.getQueryId()));
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
embeddingService.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery));
embeddingService.addQuery(solvedQueryCollection, Lists.newArrayList(textSegment));
} catch (Exception e) {
log.warn("save history question to embedding failed, queryText:{}", queryText, e);
}