diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatMemoryDO.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatMemoryDO.java index fea7ee38f..963538d49 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatMemoryDO.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/ChatMemoryDO.java @@ -6,8 +6,10 @@ import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.TableName; import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult; import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus; +import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; +import lombok.NoArgsConstructor; import lombok.ToString; import java.util.Date; @@ -15,6 +17,8 @@ import java.util.Date; @Data @Builder @ToString +@AllArgsConstructor +@NoArgsConstructor @TableName("s2_chat_memory") public class ChatMemoryDO { @TableId(type = IdType.AUTO) diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java index ad93bf790..7a95d237f 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java @@ -2,23 +2,18 @@ package com.tencent.supersonic.chat.server.processor.parse; import com.alibaba.fastjson.JSONObject; import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper; -import com.github.pagehelper.PageInfo; -import com.google.common.collect.Lists; -import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq; -import com.tencent.supersonic.chat.api.pojo.response.QueryResp; import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO; import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository; import com.tencent.supersonic.chat.server.pojo.ChatParseContext; -import com.tencent.supersonic.chat.server.util.SimilarQueryManager; +import com.tencent.supersonic.common.pojo.SqlExemplar; +import com.tencent.supersonic.common.service.ExemplarService; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; -import org.springframework.util.CollectionUtils; import java.util.List; -import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; @@ -44,25 +39,10 @@ public class QueryRecommendProcessor implements ParseResultProcessor { } public List getSimilarQueries(String queryText, Integer agentId) { - //1. recall solved query by queryText - SimilarQueryManager solvedQueryManager = ContextUtils.getBean(SimilarQueryManager.class); - List similarQueries = solvedQueryManager.recallSimilarQuery(queryText, agentId); - if (CollectionUtils.isEmpty(similarQueries)) { - return Lists.newArrayList(); - } - //2. remove low score query - List queryIds = similarQueries.stream() - .map(SimilarQueryRecallResp::getQueryId).collect(Collectors.toList()); - int lowScoreThreshold = 3; - List queryResps = getChatQuery(queryIds); - if (CollectionUtils.isEmpty(queryResps)) { - return Lists.newArrayList(); - } - Set lowScoreQueryIds = queryResps.stream().filter(queryResp -> - queryResp.getScore() != null && queryResp.getScore() <= lowScoreThreshold) - .map(QueryResp::getQuestionId).collect(Collectors.toSet()); - return similarQueries.stream().filter(solvedQueryRecallResp -> - !lowScoreQueryIds.contains(solvedQueryRecallResp.getQueryId())) + ExemplarService exemplarService = ContextUtils.getBean(ExemplarService.class); + List exemplars = exemplarService.recallExemplars(agentId.toString(), queryText, 5); + return exemplars.stream().map(sqlExemplar -> + SimilarQueryRecallResp.builder().queryText(sqlExemplar.getQuestion()).build()) .collect(Collectors.toList()); } @@ -71,16 +51,6 @@ public class QueryRecommendProcessor implements ParseResultProcessor { return chatQueryRepository.getChatQueryDO(queryId); } - private List getChatQuery(List queryIds) { - ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class); - PageQueryInfoReq pageQueryInfoReq = new PageQueryInfoReq(); - pageQueryInfoReq.setIds(queryIds); - pageQueryInfoReq.setPageSize(100); - pageQueryInfoReq.setCurrent(1); - PageInfo queryRespPageInfo = chatQueryRepository.getChatQuery(pageQueryInfoReq, null); - return queryRespPageInfo.getList(); - } - private void updateChatQuery(ChatQueryDO chatQueryDO) { ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class); UpdateWrapper updateWrapper = new UpdateWrapper<>(); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/SimilarQueryManager.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/SimilarQueryManager.java deleted file mode 100644 index 5c1dc4cf0..000000000 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/SimilarQueryManager.java +++ /dev/null @@ -1,80 +0,0 @@ -package com.tencent.supersonic.chat.server.util; - -import com.google.common.collect.Lists; -import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp; -import com.tencent.supersonic.common.config.EmbeddingConfig; -import com.tencent.supersonic.common.service.EmbeddingService; -import dev.langchain4j.store.embedding.Retrieval; -import dev.langchain4j.store.embedding.RetrieveQuery; -import dev.langchain4j.store.embedding.RetrieveQueryResult; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.collections.CollectionUtils; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.stereotype.Component; - -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; - -@Slf4j -@Component -public class SimilarQueryManager { - - private EmbeddingConfig embeddingConfig; - - @Autowired - private EmbeddingService embeddingService; - - - public SimilarQueryManager(EmbeddingConfig embeddingConfig) { - this.embeddingConfig = embeddingConfig; - } - - public List recallSimilarQuery(String queryText, Integer agentId) { - List similarQueryRecallResps = Lists.newArrayList(); - try { - String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection(); - int solvedQueryResultNum = embeddingConfig.getSolvedQueryResultNum(); - - Map filterCondition = new HashMap<>(); - filterCondition.put("agentId", String.valueOf(agentId)); - RetrieveQuery retrieveQuery = RetrieveQuery.builder() - .queryTextsList(Lists.newArrayList(queryText)) - .filterCondition(filterCondition) - .build(); - List resultList = embeddingService.retrieveQuery(solvedQueryCollection, retrieveQuery, - solvedQueryResultNum * 20); - - log.info("[embedding] recognize result body:{}", resultList); - Set querySet = new HashSet<>(); - if (CollectionUtils.isNotEmpty(resultList)) { - for (RetrieveQueryResult retrieveQueryResult : resultList) { - List retrievals = retrieveQueryResult.getRetrieval(); - for (Retrieval retrieval : retrievals) { - if (queryText.equalsIgnoreCase(retrieval.getQuery())) { - continue; - } - if (querySet.contains(retrieval.getQuery())) { - continue; - } - String id = retrieval.getId(); - SimilarQueryRecallResp similarQueryRecallResp = SimilarQueryRecallResp.builder() - .queryText(retrieval.getQuery()) - .queryId(Long.parseLong(id)) - .build(); - similarQueryRecallResps.add(similarQueryRecallResp); - querySet.add(retrieval.getQuery()); - } - } - } - - } catch (Exception e) { - log.warn("recall similar solved query failed, queryText:{}", queryText, e); - } - return similarQueryRecallResps.stream() - .limit(embeddingConfig.getSolvedQueryResultNum()).collect(Collectors.toList()); - } -} diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingConfig.java index f2a272bc8..6b47eb4cc 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingConfig.java @@ -16,12 +16,6 @@ public class EmbeddingConfig { @Value("${s2.embedding.nResult:1}") private int nResult; - @Value("${s2.embedding.solved.query.collection:solved_query_collection}") - private String solvedQueryCollection; - - @Value("${s2.embedding.solved.query.nResult:5}") - private int solvedQueryResultNum; - @Value("${s2.embedding.metric.analyzeQuery.collection:solved_query_collection}") private String metricAnalyzeQueryCollection;