(improvement)(chat) Upgrade and optimize the embedding metastore. (#1198)

This commit is contained in:
lexluo09
2024-06-23 21:46:10 +08:00
committed by GitHub
parent 2ae94fb38c
commit 4d6cbf31f7
46 changed files with 3788 additions and 498 deletions

View File

@@ -14,13 +14,12 @@ import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.service.PluginService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import dev.langchain4j.store.embedding.ComponentFactory;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.store.embedding.EmbeddingQuery;
import com.tencent.supersonic.common.service.EmbeddingService;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import dev.langchain4j.store.embedding.S2EmbeddingStore;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
@@ -49,7 +48,8 @@ public class PluginManager {
@Autowired
private EmbeddingConfig embeddingConfig;
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
@Autowired
private EmbeddingService embeddingService;
public static List<Plugin> getPluginAgentCanSupport(ChatParseContext chatParseContext) {
PluginService pluginService = ContextUtils.getBean(PluginService.class);
@@ -122,7 +122,7 @@ public class PluginManager {
embeddingQuery.setQueryId(id);
queries.add(embeddingQuery);
}
s2EmbeddingStore.deleteQuery(presetCollection, queries);
embeddingService.deleteQuery(presetCollection, queries);
}
public void requestEmbeddingPluginAdd(List<EmbeddingQuery> queries) {
@@ -130,7 +130,7 @@ public class PluginManager {
return;
}
String presetCollection = embeddingConfig.getPresetCollection();
s2EmbeddingStore.addQuery(presetCollection, queries);
embeddingService.addQuery(presetCollection, queries);
}
public void requestEmbeddingPluginAddALL(List<Plugin> plugins) {
@@ -143,7 +143,7 @@ public class PluginManager {
.queryTextsList(Collections.singletonList(embeddingText))
.build();
List<RetrieveQueryResult> resultList = s2EmbeddingStore.retrieveQuery(embeddingConfig.getPresetCollection(),
List<RetrieveQueryResult> resultList = embeddingService.retrieveQuery(embeddingConfig.getPresetCollection(),
retrieveQuery, embeddingConfig.getNResult());
if (CollectionUtils.isNotEmpty(resultList)) {

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.server.processor.execute;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.store.embedding.Retrieval;
@@ -13,6 +14,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
import java.util.Objects;
import org.springframework.util.CollectionUtils;
import java.util.Collections;
@@ -66,8 +68,13 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
}
for (Retrieval retrieval : retrievals) {
if (!metricIds.contains(Retrieval.getLongId(retrieval.getId()))) {
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(retrieval.getMetadata()),
SchemaElement.class);
if (Objects.nonNull(retrieval.getMetadata().get("id"))) {
String idStr = retrieval.getMetadata().get("id").toString()
.replaceAll(DictWordType.NATURE_SPILT, "");
retrieval.getMetadata().put("id", idStr);
}
String metaStr = JSONObject.toJSONString(retrieval.getMetadata());
SchemaElement schemaElement = JSONObject.parseObject(metaStr, SchemaElement.class);
if (retrieval.getMetadata().containsKey("dataSetId")) {
String dataSetId = retrieval.getMetadata().get("dataSetId").toString()
.replace(Constants.UNDERLINE, "");

View File

@@ -4,16 +4,24 @@ import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.request.SimilarQueryReq;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import dev.langchain4j.store.embedding.ComponentFactory;
import com.tencent.supersonic.common.service.EmbeddingService;
import dev.langchain4j.store.embedding.EmbeddingQuery;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import dev.langchain4j.store.embedding.S2EmbeddingStore;
import java.net.URI;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.util.Strings;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
@@ -24,22 +32,14 @@ import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;
import java.net.URI;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@Slf4j
@Component
public class SimilarQueryManager {
private EmbeddingConfig embeddingConfig;
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
@Autowired
private EmbeddingService embeddingService;
public SimilarQueryManager(EmbeddingConfig embeddingConfig) {
@@ -60,7 +60,7 @@ public class SimilarQueryManager {
metaData.put("agentId", similarQueryReq.getAgentId());
embeddingQuery.setMetadata(metaData);
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
s2EmbeddingStore.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery));
embeddingService.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery));
} catch (Exception e) {
log.warn("save history question to embedding failed, queryText:{}", queryText, e);
}
@@ -81,7 +81,7 @@ public class SimilarQueryManager {
.queryTextsList(Lists.newArrayList(queryText))
.filterCondition(filterCondition)
.build();
List<RetrieveQueryResult> resultList = s2EmbeddingStore.retrieveQuery(solvedQueryCollection, retrieveQuery,
List<RetrieveQueryResult> resultList = embeddingService.retrieveQuery(solvedQueryCollection, retrieveQuery,
solvedQueryResultNum * 20);
log.info("[embedding] recognize result body:{}", resultList);