(improvement)(chat) Obtain similar query from ExemplarService instead of directly from embedding store (#1278)

Co-authored-by: lxwcodemonkey
This commit is contained in:
LXW
2024-06-29 14:02:39 +08:00
committed by GitHub
parent a1083a92c2
commit a45fe183d2
4 changed files with 10 additions and 122 deletions

View File

@@ -6,8 +6,10 @@ import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryReviewResult;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.ToString;
import java.util.Date;
@@ -15,6 +17,8 @@ import java.util.Date;
@Data
@Builder
@ToString
@AllArgsConstructor
@NoArgsConstructor
@TableName("s2_chat_memory")
public class ChatMemoryDO {
@TableId(type = IdType.AUTO)

View File

@@ -2,23 +2,18 @@ package com.tencent.supersonic.chat.server.processor.parse;
import com.alibaba.fastjson.JSONObject;
import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.github.pagehelper.PageInfo;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.util.SimilarQueryManager;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.ExemplarService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
@@ -44,25 +39,10 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
}
public List<SimilarQueryRecallResp> getSimilarQueries(String queryText, Integer agentId) {
//1. recall solved query by queryText
SimilarQueryManager solvedQueryManager = ContextUtils.getBean(SimilarQueryManager.class);
List<SimilarQueryRecallResp> similarQueries = solvedQueryManager.recallSimilarQuery(queryText, agentId);
if (CollectionUtils.isEmpty(similarQueries)) {
return Lists.newArrayList();
}
//2. remove low score query
List<Long> queryIds = similarQueries.stream()
.map(SimilarQueryRecallResp::getQueryId).collect(Collectors.toList());
int lowScoreThreshold = 3;
List<QueryResp> queryResps = getChatQuery(queryIds);
if (CollectionUtils.isEmpty(queryResps)) {
return Lists.newArrayList();
}
Set<Long> lowScoreQueryIds = queryResps.stream().filter(queryResp ->
queryResp.getScore() != null && queryResp.getScore() <= lowScoreThreshold)
.map(QueryResp::getQuestionId).collect(Collectors.toSet());
return similarQueries.stream().filter(solvedQueryRecallResp ->
!lowScoreQueryIds.contains(solvedQueryRecallResp.getQueryId()))
ExemplarService exemplarService = ContextUtils.getBean(ExemplarService.class);
List<SqlExemplar> exemplars = exemplarService.recallExemplars(agentId.toString(), queryText, 5);
return exemplars.stream().map(sqlExemplar ->
SimilarQueryRecallResp.builder().queryText(sqlExemplar.getQuestion()).build())
.collect(Collectors.toList());
}
@@ -71,16 +51,6 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
return chatQueryRepository.getChatQueryDO(queryId);
}
private List<QueryResp> getChatQuery(List<Long> queryIds) {
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
PageQueryInfoReq pageQueryInfoReq = new PageQueryInfoReq();
pageQueryInfoReq.setIds(queryIds);
pageQueryInfoReq.setPageSize(100);
pageQueryInfoReq.setCurrent(1);
PageInfo<QueryResp> queryRespPageInfo = chatQueryRepository.getChatQuery(pageQueryInfoReq, null);
return queryRespPageInfo.getList();
}
private void updateChatQuery(ChatQueryDO chatQueryDO) {
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
UpdateWrapper<ChatQueryDO> updateWrapper = new UpdateWrapper<>();

View File

@@ -1,80 +0,0 @@
package com.tencent.supersonic.chat.server.util;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.service.EmbeddingService;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
@Slf4j
@Component
public class SimilarQueryManager {
private EmbeddingConfig embeddingConfig;
@Autowired
private EmbeddingService embeddingService;
public SimilarQueryManager(EmbeddingConfig embeddingConfig) {
this.embeddingConfig = embeddingConfig;
}
public List<SimilarQueryRecallResp> recallSimilarQuery(String queryText, Integer agentId) {
List<SimilarQueryRecallResp> similarQueryRecallResps = Lists.newArrayList();
try {
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
int solvedQueryResultNum = embeddingConfig.getSolvedQueryResultNum();
Map<String, String> filterCondition = new HashMap<>();
filterCondition.put("agentId", String.valueOf(agentId));
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
.queryTextsList(Lists.newArrayList(queryText))
.filterCondition(filterCondition)
.build();
List<RetrieveQueryResult> resultList = embeddingService.retrieveQuery(solvedQueryCollection, retrieveQuery,
solvedQueryResultNum * 20);
log.info("[embedding] recognize result body:{}", resultList);
Set<String> querySet = new HashSet<>();
if (CollectionUtils.isNotEmpty(resultList)) {
for (RetrieveQueryResult retrieveQueryResult : resultList) {
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
for (Retrieval retrieval : retrievals) {
if (queryText.equalsIgnoreCase(retrieval.getQuery())) {
continue;
}
if (querySet.contains(retrieval.getQuery())) {
continue;
}
String id = retrieval.getId();
SimilarQueryRecallResp similarQueryRecallResp = SimilarQueryRecallResp.builder()
.queryText(retrieval.getQuery())
.queryId(Long.parseLong(id))
.build();
similarQueryRecallResps.add(similarQueryRecallResp);
querySet.add(retrieval.getQuery());
}
}
}
} catch (Exception e) {
log.warn("recall similar solved query failed, queryText:{}", queryText, e);
}
return similarQueryRecallResps.stream()
.limit(embeddingConfig.getSolvedQueryResultNum()).collect(Collectors.toList());
}
}

View File

@@ -16,12 +16,6 @@ public class EmbeddingConfig {
@Value("${s2.embedding.nResult:1}")
private int nResult;
@Value("${s2.embedding.solved.query.collection:solved_query_collection}")
private String solvedQueryCollection;
@Value("${s2.embedding.solved.query.nResult:5}")
private int solvedQueryResultNum;
@Value("${s2.embedding.metric.analyzeQuery.collection:solved_query_collection}")
private String metricAnalyzeQueryCollection;