[improvement][project]Remove unnecessary SchemaMapInfo from ParseContext.

This commit is contained in:
jerryjzhang
2024-10-27 15:14:06 +08:00
parent 1e3daffade
commit 397b527bc6
9 changed files with 44 additions and 46 deletions

View File

@@ -100,7 +100,7 @@ public class NL2SQLParser implements ChatQueryParser {
if (!parseContext.isDisableLLM()) {
processMultiTurn(parseContext);
}
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx);
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext, chatCtx);
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
@@ -179,11 +179,11 @@ public class NL2SQLParser implements ChatQueryParser {
// derive mapping result of current question and parsing result of last question.
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
MapResp currentMapResult = chatLayerService.map(queryNLReq);
List<QueryResp> historyQueries = getHistoryQueries(parseContext.getChatId(), 1);
if (historyQueries.size() == 0) {
if (historyQueries.isEmpty()) {
return;
}
QueryResp lastQuery = historyQueries.get(0);
@@ -209,9 +209,6 @@ public class NL2SQLParser implements ChatQueryParser {
String rewrittenQuery = response.content().text();
keyPipelineLog.info("QueryRewrite modelReq:\n{} \nmodelResp:\n{}", prompt.text(), response);
parseContext.setQueryText(rewrittenQuery);
QueryNLReq rewrittenQueryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
MapResp rewrittenQueryMapResult = chatLayerService.map(rewrittenQueryNLReq);
parseContext.setMapInfo(rewrittenQueryMapResult.getMapInfo());
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", lastQuery.getQueryText(),
currentMapResult.getQueryText(), rewrittenQuery);
}

View File

@@ -13,6 +13,7 @@ import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent;
import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.chat.server.service.PluginService;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.service.EmbeddingService;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -20,6 +21,8 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery;
@@ -193,8 +196,10 @@ public class PluginManager {
}
public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ParseContext parseContext) {
SchemaMapInfo schemaMapInfo = parseContext.getMapInfo();
Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, parseContext);
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
SchemaMapInfo schemaMapInfo = chatLayerService.map(queryNLReq).getMapInfo();
Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, schemaMapInfo);
if (CollectionUtils.isEmpty(pluginMatchedDataSet) && !plugin.isContainsAllDataSet()) {
return Pair.of(false, Sets.newHashSet());
}
@@ -260,8 +265,8 @@ public class PluginManager {
.collect(Collectors.toList());
}
private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, ParseContext parseContext) {
Set<Long> matchedDataSets = parseContext.getMapInfo().getMatchedDataSetInfos();
private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, SchemaMapInfo mapInfo) {
Set<Long> matchedDataSets = mapInfo.getMatchedDataSetInfos();
if (plugin.isContainsAllDataSet()) {
return Sets.newHashSet(plugin.getDefaultMode());
}

View File

@@ -7,15 +7,20 @@ import com.tencent.supersonic.chat.server.plugin.PluginManager;
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import org.springframework.util.CollectionUtils;
import java.util.HashMap;
@@ -48,9 +53,12 @@ public abstract class PluginRecognizer {
if (plugin.isContainsAllDataSet()) {
dataSetIds = Sets.newHashSet(-1L);
}
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
SchemaMapInfo schemaMapInfo = chatLayerService.map(queryNLReq).getMapInfo();
for (Long dataSetId : dataSetIds) {
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin,
parseContext, pluginRecallResult.getDistance());
parseContext, schemaMapInfo, pluginRecallResult.getDistance());
semanticParseInfo.setQueryMode(plugin.getType());
semanticParseInfo.setScore(pluginRecallResult.getScore());
parseResp.getSelectedParses().add(semanticParseInfo);
@@ -62,9 +70,8 @@ public abstract class PluginRecognizer {
}
protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin,
ParseContext parseContext, double distance) {
List<SchemaElementMatch> schemaElementMatches =
parseContext.getMapInfo().getMatchedElements(dataSetId);
ParseContext parseContext, SchemaMapInfo mapInfo, double distance) {
List<SchemaElementMatch> schemaElementMatches = mapInfo.getMatchedElements(dataSetId);
QueryFilters queryFilters = parseContext.getQueryFilters();
if (schemaElementMatches == null) {
schemaElementMatches = Lists.newArrayList();

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.server.pojo;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import lombok.Data;
@@ -14,7 +13,6 @@ public class ParseContext {
private Integer chatId;
private QueryFilters queryFilters;
private boolean saveAnswer = true;
private SchemaMapInfo mapInfo;
private boolean disableLLM = false;
public boolean enableNL2SQL() {

View File

@@ -38,7 +38,6 @@ import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
@@ -90,11 +89,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
@Autowired
private ChatModelService chatModelService;
private List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
private List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors();
private List<ParseResultProcessor> parseResultProcessors =
private final List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
private final List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors();
private final List<ParseResultProcessor> parseResultProcessors =
ComponentFactory.getParseProcessors();
private List<ExecuteResultProcessor> executeResultProcessors =
private final List<ExecuteResultProcessor> executeResultProcessors =
ComponentFactory.getExecuteProcessors();
@Override
@@ -104,7 +103,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
if (!agent.enableSearch()) {
return Lists.newArrayList();
}
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
return chatLayerService.retrieve(queryNLReq);
}
@@ -113,13 +112,14 @@ public class ChatQueryServiceImpl implements ChatQueryService {
ParseResp parseResp = new ParseResp(chatParseReq.getQueryText());
chatManageService.createChatQuery(chatParseReq, parseResp);
ParseContext parseContext = buildParseContext(chatParseReq);
supplyMapInfo(parseContext);
for (ChatQueryParser chatQueryParser : chatQueryParsers) {
chatQueryParser.parse(parseContext, parseResp);
for (ChatQueryParser parser : chatQueryParsers) {
parser.parse(parseContext, parseResp);
}
for (ParseResultProcessor processor : parseResultProcessors) {
processor.process(parseContext, parseResp);
}
chatParseReq.setQueryText(parseContext.getQueryText());
chatManageService.batchAddParse(chatParseReq, parseResp);
chatManageService.updateParseCostTime(parseResp);
@@ -175,12 +175,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return parseContext;
}
private void supplyMapInfo(ParseContext parseContext) {
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
MapResp mapResp = chatLayerService.map(queryNLReq);
parseContext.setMapInfo(mapResp.getMapInfo());
}
private ExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
ExecuteContext executeContext = new ExecuteContext();
BeanMapper.mapper(chatExecuteReq, executeContext);
@@ -197,7 +191,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
Integer parseId = chatQueryDataReq.getParseId();
SemanticParseInfo parseInfo =
chatManageService.getParseInfo(chatQueryDataReq.getQueryId(), parseId);
parseInfo = mergeParseInfo(parseInfo, chatQueryDataReq);
mergeParseInfo(parseInfo, chatQueryDataReq);
DataSetSchema dataSetSchema =
semanticLayerService.getDataSetSchema(parseInfo.getDataSetId());
@@ -494,10 +488,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
});
}
private SemanticParseInfo mergeParseInfo(SemanticParseInfo parseInfo,
ChatQueryDataReq queryData) {
private void mergeParseInfo(SemanticParseInfo parseInfo, ChatQueryDataReq queryData) {
if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
return parseInfo;
return;
}
if (!CollectionUtils.isEmpty(queryData.getDimensions())) {
parseInfo.setDimensions(queryData.getDimensions());
@@ -515,7 +508,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
parseInfo.setDateInfo(queryData.getDateInfo());
}
parseInfo.setSqlInfo(new SqlInfo());
return parseInfo;
}
private void validFilter(Set<QueryFilter> filters) {

View File

@@ -12,11 +12,11 @@ import java.util.Objects;
public class QueryReqConverter {
public static QueryNLReq buildText2SqlQueryReq(ParseContext parseContext) {
return buildText2SqlQueryReq(parseContext, null);
public static QueryNLReq buildQueryNLReq(ParseContext parseContext) {
return buildQueryNLReq(parseContext, null);
}
public static QueryNLReq buildText2SqlQueryReq(ParseContext parseContext, ChatContext chatCtx) {
public static QueryNLReq buildQueryNLReq(ParseContext parseContext, ChatContext chatCtx) {
QueryNLReq queryNLReq = new QueryNLReq();
BeanMapper.mapper(parseContext, queryNLReq);
Agent agent = parseContext.getAgent();

View File

@@ -78,8 +78,7 @@ public class QueryTypeParser implements SemanticParser {
}
private static List<String> filterByTimeFields(List<String> whereFields) {
return whereFields.stream()
.filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
return whereFields.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
.collect(Collectors.toList());
}

View File

@@ -89,7 +89,7 @@ public class S2ChatLayerService implements ChatLayerService {
public ParseResp parse(QueryNLReq queryNLReq) {
ParseResp parseResult = new ParseResp(queryNLReq.getQueryText());
ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq);
chatWorkflowEngine.execute(queryCtx, parseResult);
chatWorkflowEngine.start(queryCtx, parseResult);
return parseResult;
}

View File

@@ -36,7 +36,7 @@ public class ChatWorkflowEngine {
ComponentFactory.getSemanticCorrectors();
private final List<ResultProcessor> resultProcessors = ComponentFactory.getResultProcessors();
public void execute(ChatQueryContext queryCtx, ParseResp parseResult) {
public void start(ChatQueryContext queryCtx, ParseResp parseResult) {
queryCtx.setChatWorkflowState(ChatWorkflowState.MAPPING);
while (queryCtx.getChatWorkflowState() != ChatWorkflowState.FINISHED) {
switch (queryCtx.getChatWorkflowState()) {
@@ -122,8 +122,8 @@ public class ChatWorkflowEngine {
resultProcessors.forEach(processor -> processor.process(parseResult, queryCtx));
}
private void performTranslating(ChatQueryContext chatQueryContext, ParseResp parseResult) {
List<SemanticParseInfo> semanticParseInfos = chatQueryContext.getCandidateQueries().stream()
private void performTranslating(ChatQueryContext queryCtx, ParseResp parseResult) {
List<SemanticParseInfo> semanticParseInfos = queryCtx.getCandidateQueries().stream()
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
List<String> errorMsg = new ArrayList<>();
if (StringUtils.isNotBlank(parseResult.getErrorMsg())) {
@@ -140,7 +140,7 @@ public class ChatWorkflowEngine {
SemanticLayerService queryService =
ContextUtils.getBean(SemanticLayerService.class);
SemanticTranslateResp explain =
queryService.translate(semanticQueryReq, chatQueryContext.getUser());
queryService.translate(semanticQueryReq, queryCtx.getUser());
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
if (StringUtils.isNotBlank(explain.getErrMsg())) {
errorMsg.add(explain.getErrMsg());