mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 13:47:09 +00:00
(improvement)(chat) Upgrade and optimize the embedding metastore. (#1198)
This commit is contained in:
@@ -14,13 +14,12 @@ import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.service.PluginService;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import dev.langchain4j.store.embedding.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.store.embedding.EmbeddingQuery;
|
||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
import dev.langchain4j.store.embedding.S2EmbeddingStore;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
@@ -49,7 +48,8 @@ public class PluginManager {
|
||||
@Autowired
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||
@Autowired
|
||||
private EmbeddingService embeddingService;
|
||||
|
||||
public static List<Plugin> getPluginAgentCanSupport(ChatParseContext chatParseContext) {
|
||||
PluginService pluginService = ContextUtils.getBean(PluginService.class);
|
||||
@@ -122,7 +122,7 @@ public class PluginManager {
|
||||
embeddingQuery.setQueryId(id);
|
||||
queries.add(embeddingQuery);
|
||||
}
|
||||
s2EmbeddingStore.deleteQuery(presetCollection, queries);
|
||||
embeddingService.deleteQuery(presetCollection, queries);
|
||||
}
|
||||
|
||||
public void requestEmbeddingPluginAdd(List<EmbeddingQuery> queries) {
|
||||
@@ -130,7 +130,7 @@ public class PluginManager {
|
||||
return;
|
||||
}
|
||||
String presetCollection = embeddingConfig.getPresetCollection();
|
||||
s2EmbeddingStore.addQuery(presetCollection, queries);
|
||||
embeddingService.addQuery(presetCollection, queries);
|
||||
}
|
||||
|
||||
public void requestEmbeddingPluginAddALL(List<Plugin> plugins) {
|
||||
@@ -143,7 +143,7 @@ public class PluginManager {
|
||||
.queryTextsList(Collections.singletonList(embeddingText))
|
||||
.build();
|
||||
|
||||
List<RetrieveQueryResult> resultList = s2EmbeddingStore.retrieveQuery(embeddingConfig.getPresetCollection(),
|
||||
List<RetrieveQueryResult> resultList = embeddingService.retrieveQuery(embeddingConfig.getPresetCollection(),
|
||||
retrieveQuery, embeddingConfig.getNResult());
|
||||
|
||||
if (CollectionUtils.isNotEmpty(resultList)) {
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.server.processor.execute;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
@@ -13,6 +14,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.chat.knowledge.MetaEmbeddingService;
|
||||
import java.util.Objects;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Collections;
|
||||
@@ -66,8 +68,13 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
|
||||
}
|
||||
for (Retrieval retrieval : retrievals) {
|
||||
if (!metricIds.contains(Retrieval.getLongId(retrieval.getId()))) {
|
||||
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(retrieval.getMetadata()),
|
||||
SchemaElement.class);
|
||||
if (Objects.nonNull(retrieval.getMetadata().get("id"))) {
|
||||
String idStr = retrieval.getMetadata().get("id").toString()
|
||||
.replaceAll(DictWordType.NATURE_SPILT, "");
|
||||
retrieval.getMetadata().put("id", idStr);
|
||||
}
|
||||
String metaStr = JSONObject.toJSONString(retrieval.getMetadata());
|
||||
SchemaElement schemaElement = JSONObject.parseObject(metaStr, SchemaElement.class);
|
||||
if (retrieval.getMetadata().containsKey("dataSetId")) {
|
||||
String dataSetId = retrieval.getMetadata().get("dataSetId").toString()
|
||||
.replace(Constants.UNDERLINE, "");
|
||||
|
||||
@@ -4,16 +4,24 @@ import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.SimilarQueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import dev.langchain4j.store.embedding.ComponentFactory;
|
||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||
import dev.langchain4j.store.embedding.EmbeddingQuery;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
import dev.langchain4j.store.embedding.S2EmbeddingStore;
|
||||
import java.net.URI;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
@@ -24,22 +32,14 @@ import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
import java.net.URI;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class SimilarQueryManager {
|
||||
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||
@Autowired
|
||||
private EmbeddingService embeddingService;
|
||||
|
||||
|
||||
public SimilarQueryManager(EmbeddingConfig embeddingConfig) {
|
||||
@@ -60,7 +60,7 @@ public class SimilarQueryManager {
|
||||
metaData.put("agentId", similarQueryReq.getAgentId());
|
||||
embeddingQuery.setMetadata(metaData);
|
||||
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
|
||||
s2EmbeddingStore.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery));
|
||||
embeddingService.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery));
|
||||
} catch (Exception e) {
|
||||
log.warn("save history question to embedding failed, queryText:{}", queryText, e);
|
||||
}
|
||||
@@ -81,7 +81,7 @@ public class SimilarQueryManager {
|
||||
.queryTextsList(Lists.newArrayList(queryText))
|
||||
.filterCondition(filterCondition)
|
||||
.build();
|
||||
List<RetrieveQueryResult> resultList = s2EmbeddingStore.retrieveQuery(solvedQueryCollection, retrieveQuery,
|
||||
List<RetrieveQueryResult> resultList = embeddingService.retrieveQuery(solvedQueryCollection, retrieveQuery,
|
||||
solvedQueryResultNum * 20);
|
||||
|
||||
log.info("[embedding] recognize result body:{}", resultList);
|
||||
|
||||
Reference in New Issue
Block a user