(improvement)(chat) Remove QueryReq parameter from QueryContext. (#656)

This commit is contained in:
lexluo09
2024-01-19 16:17:31 +08:00
committed by GitHub
parent f017f41201
commit cbf38ed785
35 changed files with 115 additions and 152 deletions

View File

@@ -1,15 +1,14 @@
package com.tencent.supersonic.chat.server.persistence.repository;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import java.util.List;
public interface ChatQueryRepository {
@@ -22,15 +21,12 @@ public interface ChatQueryRepository {
List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId);
void updateChatParseInfo(List<ChatParseDO> chatParseDOS);
ChatQueryDO getLastChatQuery(long chatId);
int updateChatQuery(ChatQueryDO chatQueryDO);
List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
ParseResp parseResult,
List<SemanticParseInfo> candidateParses);
List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryContext queryContext,
ParseResp parseResult, List<SemanticParseInfo> candidateParses);
ChatParseDO getParseInfo(Long questionId, int parseId);

View File

@@ -3,35 +3,35 @@ package com.tencent.supersonic.chat.server.persistence.repository.impl;
import com.alibaba.fastjson.JSONObject;
import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.chat.server.persistence.mapper.custom.ShowCaseCustomMapper;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDOExample;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDOExample.Criteria;
import com.tencent.supersonic.chat.server.persistence.mapper.ChatParseMapper;
import com.tencent.supersonic.chat.server.persistence.mapper.ChatQueryDOMapper;
import com.tencent.supersonic.chat.server.persistence.mapper.custom.ShowCaseCustomMapper;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.common.util.PageUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Repository;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
@Repository
@Primary
@@ -116,13 +116,13 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
return queryResp;
}
public Long createChatQuery(ParseResp parseResult, ChatContext chatCtx, QueryReq queryReq) {
public Long createChatQuery(ParseResp parseResult, ChatContext chatCtx, QueryContext queryContext) {
ChatQueryDO chatQueryDO = new ChatQueryDO();
chatQueryDO.setChatId(Long.valueOf(chatCtx.getChatId()));
chatQueryDO.setCreateTime(new java.util.Date());
chatQueryDO.setUserName(queryReq.getUser().getName());
chatQueryDO.setQueryText(queryReq.getQueryText());
chatQueryDO.setAgentId(queryReq.getAgentId());
chatQueryDO.setUserName(queryContext.getUser().getName());
chatQueryDO.setQueryText(queryContext.getQueryText());
chatQueryDO.setAgentId(queryContext.getAgentId());
chatQueryDO.setQueryResult("");
try {
chatQueryDOMapper.insert(chatQueryDO);
@@ -135,31 +135,24 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
}
@Override
public List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryReq queryReq,
public List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryContext queryContext,
ParseResp parseResult, List<SemanticParseInfo> candidateParses) {
Long queryId = createChatQuery(parseResult, chatCtx, queryReq);
Long queryId = createChatQuery(parseResult, chatCtx, queryContext);
List<ChatParseDO> chatParseDOList = new ArrayList<>();
getChatParseDO(chatCtx, queryReq, queryId, candidateParses, chatParseDOList);
getChatParseDO(chatCtx, queryContext, queryId, candidateParses, chatParseDOList);
if (!CollectionUtils.isEmpty(candidateParses)) {
chatParseMapper.batchSaveParseInfo(chatParseDOList);
}
return chatParseDOList;
}
@Override
public void updateChatParseInfo(List<ChatParseDO> chatParseDOS) {
for (ChatParseDO chatParseDO : chatParseDOS) {
chatParseMapper.updateParseInfo(chatParseDO);
}
}
public void getChatParseDO(ChatContext chatCtx, QueryReq queryReq, Long queryId,
public void getChatParseDO(ChatContext chatCtx, QueryContext queryContext, Long queryId,
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
for (int i = 0; i < parses.size(); i++) {
ChatParseDO chatParseDO = new ChatParseDO();
chatParseDO.setChatId(Long.valueOf(chatCtx.getChatId()));
chatParseDO.setQuestionId(queryId);
chatParseDO.setQueryText(queryReq.getQueryText());
chatParseDO.setQueryText(queryContext.getQueryText());
chatParseDO.setParseInfo(JsonUtil.toString(parses.get(i)));
chatParseDO.setIsCandidate(1);
if (i == 0) {
@@ -167,7 +160,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
}
chatParseDO.setParseId(parses.get(i).getId());
chatParseDO.setCreateTime(new java.util.Date());
chatParseDO.setUserName(queryReq.getUser().getName());
chatParseDO.setUserName(queryContext.getUser().getName());
chatParseDOList.add(chatParseDO);
}
}

View File

@@ -1,13 +1,12 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.query.llm.analytics.MetricAnalyzeQuery;
import com.tencent.supersonic.chat.server.service.SemanticService;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -29,7 +28,6 @@ public class EntityInfoProcessor implements ParseResultProcessor {
}
List<SemanticParseInfo> selectedParses = semanticQueries.stream().map(SemanticQuery::getParseInfo)
.collect(Collectors.toList());
QueryReq queryReq = queryContext.getRequest();
selectedParses.forEach(parseInfo -> {
String queryMode = parseInfo.getQueryMode();
if (QueryManager.containsPluginQuery(queryMode)
@@ -38,7 +36,7 @@ public class EntityInfoProcessor implements ParseResultProcessor {
}
//1. set entity info
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, queryReq.getUser());
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, queryContext.getUser());
if (QueryManager.isTagQuery(queryMode)
|| QueryManager.isMetricQuery(queryMode)) {
parseInfo.setEntityInfo(entityInfo);

View File

@@ -35,8 +35,8 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
@SneakyThrows
private void doProcess(ParseResp parseResp, QueryContext queryContext) {
Long queryId = parseResp.getQueryId();
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(queryContext.getRequest().getQueryText(),
queryContext.getRequest().getAgentId());
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(queryContext.getQueryText(),
queryContext.getAgentId());
ChatQueryDO chatQueryDO = getChatQuery(queryId);
chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries));
updateChatQuery(chatQueryDO);

View File

@@ -1,11 +1,10 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.common.util.ContextUtils;
import java.util.List;
@@ -20,9 +19,8 @@ public class RespBuildProcessor implements ParseResultProcessor {
@Override
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
QueryReq queryReq = queryContext.getRequest();
parseResp.setChatId(queryReq.getChatId());
parseResp.setQueryText(queryReq.getQueryText());
parseResp.setChatId(queryContext.getChatId());
parseResp.setQueryText(queryContext.getQueryText());
List<SemanticQuery> candidateQueries = queryContext.getCandidateQueries();
ChatService chatService = ContextUtils.getBean(ChatService.class);
if (candidateQueries.size() > 0) {
@@ -33,7 +31,7 @@ public class RespBuildProcessor implements ParseResultProcessor {
} else {
parseResp.setState(ParseResp.ParseState.FAILED);
}
chatService.batchAddParse(chatContext, queryReq, parseResp);
chatService.batchAddParse(chatContext, queryContext, parseResp);
}
}

View File

@@ -1,21 +1,20 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.core.query.QueryManager;
import com.tencent.supersonic.chat.core.query.SemanticQuery;
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* SqlInfoProcessor adds S2SQL to the parsing results so that
@@ -27,7 +26,6 @@ public class SqlInfoProcessor implements ParseResultProcessor {
@Override
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
QueryReq queryReq = queryContext.getRequest();
List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
if (CollectionUtils.isEmpty(semanticQueries)) {
return;
@@ -35,26 +33,26 @@ public class SqlInfoProcessor implements ParseResultProcessor {
List<SemanticParseInfo> selectedParses = semanticQueries.stream().map(SemanticQuery::getParseInfo)
.collect(Collectors.toList());
long startTime = System.currentTimeMillis();
addSqlInfo(queryReq, selectedParses);
addSqlInfo(queryContext, selectedParses);
parseResp.getParseTimeCost().setSqlTime(System.currentTimeMillis() - startTime);
}
private void addSqlInfo(QueryReq queryReq, List<SemanticParseInfo> semanticParseInfos) {
private void addSqlInfo(QueryContext queryContext, List<SemanticParseInfo> semanticParseInfos) {
if (CollectionUtils.isEmpty(semanticParseInfos)) {
return;
}
semanticParseInfos.forEach(parseInfo -> {
addSqlInfo(queryReq, parseInfo);
addSqlInfo(queryContext, parseInfo);
});
}
private void addSqlInfo(QueryReq queryReq, SemanticParseInfo parseInfo) {
private void addSqlInfo(QueryContext queryContext, SemanticParseInfo parseInfo) {
SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode());
if (Objects.isNull(semanticQuery)) {
return;
}
semanticQuery.setParseInfo(parseInfo);
String explainSql = semanticQuery.explain(queryReq.getUser());
String explainSql = semanticQuery.explain(queryContext.getUser());
if (StringUtils.isBlank(explainSql)) {
return;
}

View File

@@ -4,11 +4,11 @@ import com.github.pagehelper.PageInfo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
@@ -46,7 +46,7 @@ public interface ChatService {
ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId);
List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult);
List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryContext queryContext, ParseResp parseResult);
ChatQueryDO getLastQuery(long chatId);

View File

@@ -5,11 +5,11 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.core.pojo.ChatContext;
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
import com.tencent.supersonic.chat.core.pojo.QueryContext;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
@@ -211,9 +211,9 @@ public class ChatServiceImpl implements ChatService {
}
@Override
public List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryReq queryReq, ParseResp parseResult) {
public List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryContext queryContext, ParseResp parseResult) {
List<SemanticParseInfo> candidateParses = parseResult.getSelectedParses();
return chatQueryRepository.batchSaveParseInfo(chatCtx, queryReq, parseResult, candidateParses);
return chatQueryRepository.batchSaveParseInfo(chatCtx, queryContext, parseResult, candidateParses);
}
@Override

View File

@@ -90,6 +90,7 @@ import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Primary;
@@ -197,7 +198,6 @@ public class QueryServiceImpl implements QueryService {
List<Plugin> pluginList = pluginService.getPluginList();
QueryContext queryCtx = QueryContext.builder()
.request(queryReq)
.queryFilters(queryReq.getQueryFilters())
.semanticSchema(semanticSchema)
.candidateQueries(new ArrayList<>())
@@ -207,6 +207,7 @@ public class QueryServiceImpl implements QueryService {
.nameToPlugin(nameToPlugin)
.pluginList(pluginList)
.build();
BeanUtils.copyProperties(queryReq, queryCtx);
return queryCtx;
}

View File

@@ -42,6 +42,7 @@ import java.util.Set;
import java.util.stream.Collectors;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
@@ -91,10 +92,10 @@ public class SearchServiceImpl implements SearchService {
List<Term> originals = HanlpHelper.getTerms(queryText);
log.info("hanlp parse result: {}", originals);
MapperHelper mapperHelper = ContextUtils.getBean(MapperHelper.class);
Set<Long> detectModelIds = mapperHelper.getModelIds(queryReq, agentService.getAgent(agentId));
Set<Long> detectModelIds = mapperHelper.getModelIds(queryReq.getModelId(), agentService.getAgent(agentId));
QueryContext queryContext = new QueryContext();
queryContext.setRequest(queryReq);
BeanUtils.copyProperties(queryReq, queryContext);
Map<MatchText, List<HanlpMapResult>> regTextMap =
searchMatchStrategy.match(queryContext, originals, detectModelIds);
regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue()));