mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
(improvement)(chat) Obtain similar query from ExemplarService instead of directly from embedding store (#1278)
Co-authored-by: lxwcodemonkey
This commit is contained in:
@@ -6,8 +6,10 @@ import com.baomidou.mybatisplus.annotation.TableId;
|
||||
import com.baomidou.mybatisplus.annotation.TableName;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.Date;
|
||||
@@ -15,6 +17,8 @@ import java.util.Date;
|
||||
@Data
|
||||
@Builder
|
||||
@ToString
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
@TableName("s2_chat_memory")
|
||||
public class ChatMemoryDO {
|
||||
@TableId(type = IdType.AUTO)
|
||||
|
||||
@@ -2,23 +2,18 @@ package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.util.SimilarQueryManager;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@@ -44,25 +39,10 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
|
||||
}
|
||||
|
||||
public List<SimilarQueryRecallResp> getSimilarQueries(String queryText, Integer agentId) {
|
||||
//1. recall solved query by queryText
|
||||
SimilarQueryManager solvedQueryManager = ContextUtils.getBean(SimilarQueryManager.class);
|
||||
List<SimilarQueryRecallResp> similarQueries = solvedQueryManager.recallSimilarQuery(queryText, agentId);
|
||||
if (CollectionUtils.isEmpty(similarQueries)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
//2. remove low score query
|
||||
List<Long> queryIds = similarQueries.stream()
|
||||
.map(SimilarQueryRecallResp::getQueryId).collect(Collectors.toList());
|
||||
int lowScoreThreshold = 3;
|
||||
List<QueryResp> queryResps = getChatQuery(queryIds);
|
||||
if (CollectionUtils.isEmpty(queryResps)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
Set<Long> lowScoreQueryIds = queryResps.stream().filter(queryResp ->
|
||||
queryResp.getScore() != null && queryResp.getScore() <= lowScoreThreshold)
|
||||
.map(QueryResp::getQuestionId).collect(Collectors.toSet());
|
||||
return similarQueries.stream().filter(solvedQueryRecallResp ->
|
||||
!lowScoreQueryIds.contains(solvedQueryRecallResp.getQueryId()))
|
||||
ExemplarService exemplarService = ContextUtils.getBean(ExemplarService.class);
|
||||
List<SqlExemplar> exemplars = exemplarService.recallExemplars(agentId.toString(), queryText, 5);
|
||||
return exemplars.stream().map(sqlExemplar ->
|
||||
SimilarQueryRecallResp.builder().queryText(sqlExemplar.getQuestion()).build())
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@@ -71,16 +51,6 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
|
||||
return chatQueryRepository.getChatQueryDO(queryId);
|
||||
}
|
||||
|
||||
private List<QueryResp> getChatQuery(List<Long> queryIds) {
|
||||
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
|
||||
PageQueryInfoReq pageQueryInfoReq = new PageQueryInfoReq();
|
||||
pageQueryInfoReq.setIds(queryIds);
|
||||
pageQueryInfoReq.setPageSize(100);
|
||||
pageQueryInfoReq.setCurrent(1);
|
||||
PageInfo<QueryResp> queryRespPageInfo = chatQueryRepository.getChatQuery(pageQueryInfoReq, null);
|
||||
return queryRespPageInfo.getList();
|
||||
}
|
||||
|
||||
private void updateChatQuery(ChatQueryDO chatQueryDO) {
|
||||
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
|
||||
UpdateWrapper<ChatQueryDO> updateWrapper = new UpdateWrapper<>();
|
||||
|
||||
@@ -1,80 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.util;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class SimilarQueryManager {
|
||||
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
@Autowired
|
||||
private EmbeddingService embeddingService;
|
||||
|
||||
|
||||
public SimilarQueryManager(EmbeddingConfig embeddingConfig) {
|
||||
this.embeddingConfig = embeddingConfig;
|
||||
}
|
||||
|
||||
public List<SimilarQueryRecallResp> recallSimilarQuery(String queryText, Integer agentId) {
|
||||
List<SimilarQueryRecallResp> similarQueryRecallResps = Lists.newArrayList();
|
||||
try {
|
||||
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
|
||||
int solvedQueryResultNum = embeddingConfig.getSolvedQueryResultNum();
|
||||
|
||||
Map<String, String> filterCondition = new HashMap<>();
|
||||
filterCondition.put("agentId", String.valueOf(agentId));
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
||||
.queryTextsList(Lists.newArrayList(queryText))
|
||||
.filterCondition(filterCondition)
|
||||
.build();
|
||||
List<RetrieveQueryResult> resultList = embeddingService.retrieveQuery(solvedQueryCollection, retrieveQuery,
|
||||
solvedQueryResultNum * 20);
|
||||
|
||||
log.info("[embedding] recognize result body:{}", resultList);
|
||||
Set<String> querySet = new HashSet<>();
|
||||
if (CollectionUtils.isNotEmpty(resultList)) {
|
||||
for (RetrieveQueryResult retrieveQueryResult : resultList) {
|
||||
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
||||
for (Retrieval retrieval : retrievals) {
|
||||
if (queryText.equalsIgnoreCase(retrieval.getQuery())) {
|
||||
continue;
|
||||
}
|
||||
if (querySet.contains(retrieval.getQuery())) {
|
||||
continue;
|
||||
}
|
||||
String id = retrieval.getId();
|
||||
SimilarQueryRecallResp similarQueryRecallResp = SimilarQueryRecallResp.builder()
|
||||
.queryText(retrieval.getQuery())
|
||||
.queryId(Long.parseLong(id))
|
||||
.build();
|
||||
similarQueryRecallResps.add(similarQueryRecallResp);
|
||||
querySet.add(retrieval.getQuery());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
log.warn("recall similar solved query failed, queryText:{}", queryText, e);
|
||||
}
|
||||
return similarQueryRecallResps.stream()
|
||||
.limit(embeddingConfig.getSolvedQueryResultNum()).collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
@@ -16,12 +16,6 @@ public class EmbeddingConfig {
|
||||
@Value("${s2.embedding.nResult:1}")
|
||||
private int nResult;
|
||||
|
||||
@Value("${s2.embedding.solved.query.collection:solved_query_collection}")
|
||||
private String solvedQueryCollection;
|
||||
|
||||
@Value("${s2.embedding.solved.query.nResult:5}")
|
||||
private int solvedQueryResultNum;
|
||||
|
||||
@Value("${s2.embedding.metric.analyzeQuery.collection:solved_query_collection}")
|
||||
private String metricAnalyzeQueryCollection;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user