From 397b527bc6243789f420758ce14012bd318d428d Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Sun, 27 Oct 2024 15:14:06 +0800 Subject: [PATCH] [improvement][project]Remove unnecessary SchemaMapInfo from `ParseContext`. --- .../chat/server/parser/NL2SQLParser.java | 9 ++---- .../chat/server/plugin/PluginManager.java | 13 +++++--- .../plugin/recognize/PluginRecognizer.java | 15 ++++++--- .../chat/server/pojo/ParseContext.java | 2 -- .../service/impl/ChatQueryServiceImpl.java | 32 +++++++------------ .../chat/server/util/QueryReqConverter.java | 6 ++-- .../headless/chat/parser/QueryTypeParser.java | 3 +- .../service/impl/S2ChatLayerService.java | 2 +- .../server/utils/ChatWorkflowEngine.java | 8 ++--- 9 files changed, 44 insertions(+), 46 deletions(-) diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 486bef058..15c6998f5 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -100,7 +100,7 @@ public class NL2SQLParser implements ChatQueryParser { if (!parseContext.isDisableLLM()) { processMultiTurn(parseContext); } - QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx); + QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext, chatCtx); addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq); ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class); @@ -179,11 +179,11 @@ public class NL2SQLParser implements ChatQueryParser { // derive mapping result of current question and parsing result of last question. ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class); - QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext); + QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext); MapResp currentMapResult = chatLayerService.map(queryNLReq); List historyQueries = getHistoryQueries(parseContext.getChatId(), 1); - if (historyQueries.size() == 0) { + if (historyQueries.isEmpty()) { return; } QueryResp lastQuery = historyQueries.get(0); @@ -209,9 +209,6 @@ public class NL2SQLParser implements ChatQueryParser { String rewrittenQuery = response.content().text(); keyPipelineLog.info("QueryRewrite modelReq:\n{} \nmodelResp:\n{}", prompt.text(), response); parseContext.setQueryText(rewrittenQuery); - QueryNLReq rewrittenQueryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext); - MapResp rewrittenQueryMapResult = chatLayerService.map(rewrittenQueryNLReq); - parseContext.setMapInfo(rewrittenQueryMapResult.getMapInfo()); log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", lastQuery.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java index 35d3598db..eecc6202e 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java @@ -13,6 +13,7 @@ import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent; import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent; import com.tencent.supersonic.chat.server.pojo.ParseContext; import com.tencent.supersonic.chat.server.service.PluginService; +import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.service.EmbeddingService; import com.tencent.supersonic.common.util.ContextUtils; @@ -20,6 +21,8 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; +import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; +import com.tencent.supersonic.headless.server.facade.service.ChatLayerService; import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.Retrieval; import dev.langchain4j.store.embedding.RetrieveQuery; @@ -193,8 +196,10 @@ public class PluginManager { } public static Pair> resolve(ChatPlugin plugin, ParseContext parseContext) { - SchemaMapInfo schemaMapInfo = parseContext.getMapInfo(); - Set pluginMatchedDataSet = getPluginMatchedDataSet(plugin, parseContext); + ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class); + QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext); + SchemaMapInfo schemaMapInfo = chatLayerService.map(queryNLReq).getMapInfo(); + Set pluginMatchedDataSet = getPluginMatchedDataSet(plugin, schemaMapInfo); if (CollectionUtils.isEmpty(pluginMatchedDataSet) && !plugin.isContainsAllDataSet()) { return Pair.of(false, Sets.newHashSet()); } @@ -260,8 +265,8 @@ public class PluginManager { .collect(Collectors.toList()); } - private static Set getPluginMatchedDataSet(ChatPlugin plugin, ParseContext parseContext) { - Set matchedDataSets = parseContext.getMapInfo().getMatchedDataSetInfos(); + private static Set getPluginMatchedDataSet(ChatPlugin plugin, SchemaMapInfo mapInfo) { + Set matchedDataSets = mapInfo.getMatchedDataSetInfos(); if (plugin.isContainsAllDataSet()) { return Sets.newHashSet(plugin.getDefaultMode()); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java index c4d2935e1..86db67436 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java @@ -7,15 +7,20 @@ import com.tencent.supersonic.chat.server.plugin.PluginManager; import com.tencent.supersonic.chat.server.plugin.PluginParseResult; import com.tencent.supersonic.chat.server.plugin.PluginRecallResult; import com.tencent.supersonic.chat.server.pojo.ParseContext; +import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; +import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; +import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; +import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; +import com.tencent.supersonic.headless.server.facade.service.ChatLayerService; import org.springframework.util.CollectionUtils; import java.util.HashMap; @@ -48,9 +53,12 @@ public abstract class PluginRecognizer { if (plugin.isContainsAllDataSet()) { dataSetIds = Sets.newHashSet(-1L); } + ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class); + QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext); + SchemaMapInfo schemaMapInfo = chatLayerService.map(queryNLReq).getMapInfo(); for (Long dataSetId : dataSetIds) { SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin, - parseContext, pluginRecallResult.getDistance()); + parseContext, schemaMapInfo, pluginRecallResult.getDistance()); semanticParseInfo.setQueryMode(plugin.getType()); semanticParseInfo.setScore(pluginRecallResult.getScore()); parseResp.getSelectedParses().add(semanticParseInfo); @@ -62,9 +70,8 @@ public abstract class PluginRecognizer { } protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin, - ParseContext parseContext, double distance) { - List schemaElementMatches = - parseContext.getMapInfo().getMatchedElements(dataSetId); + ParseContext parseContext, SchemaMapInfo mapInfo, double distance) { + List schemaElementMatches = mapInfo.getMatchedElements(dataSetId); QueryFilters queryFilters = parseContext.getQueryFilters(); if (schemaElementMatches == null) { schemaElementMatches = Lists.newArrayList(); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java index 9b7188df8..c6bf84011 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java @@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.server.pojo; import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.common.pojo.User; -import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import lombok.Data; @@ -14,7 +13,6 @@ public class ParseContext { private Integer chatId; private QueryFilters queryFilters; private boolean saveAnswer = true; - private SchemaMapInfo mapInfo; private boolean disableLLM = false; public boolean enableNL2SQL() { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java index 56d251c5d..401c581a0 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java @@ -38,7 +38,6 @@ import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; -import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.QueryState; import com.tencent.supersonic.headless.api.pojo.response.SearchResult; @@ -90,11 +89,11 @@ public class ChatQueryServiceImpl implements ChatQueryService { @Autowired private ChatModelService chatModelService; - private List chatQueryParsers = ComponentFactory.getChatParsers(); - private List chatQueryExecutors = ComponentFactory.getChatExecutors(); - private List parseResultProcessors = + private final List chatQueryParsers = ComponentFactory.getChatParsers(); + private final List chatQueryExecutors = ComponentFactory.getChatExecutors(); + private final List parseResultProcessors = ComponentFactory.getParseProcessors(); - private List executeResultProcessors = + private final List executeResultProcessors = ComponentFactory.getExecuteProcessors(); @Override @@ -104,7 +103,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { if (!agent.enableSearch()) { return Lists.newArrayList(); } - QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext); + QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext); return chatLayerService.retrieve(queryNLReq); } @@ -113,13 +112,14 @@ public class ChatQueryServiceImpl implements ChatQueryService { ParseResp parseResp = new ParseResp(chatParseReq.getQueryText()); chatManageService.createChatQuery(chatParseReq, parseResp); ParseContext parseContext = buildParseContext(chatParseReq); - supplyMapInfo(parseContext); - for (ChatQueryParser chatQueryParser : chatQueryParsers) { - chatQueryParser.parse(parseContext, parseResp); + + for (ChatQueryParser parser : chatQueryParsers) { + parser.parse(parseContext, parseResp); } for (ParseResultProcessor processor : parseResultProcessors) { processor.process(parseContext, parseResp); } + chatParseReq.setQueryText(parseContext.getQueryText()); chatManageService.batchAddParse(chatParseReq, parseResp); chatManageService.updateParseCostTime(parseResp); @@ -175,12 +175,6 @@ public class ChatQueryServiceImpl implements ChatQueryService { return parseContext; } - private void supplyMapInfo(ParseContext parseContext) { - QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext); - MapResp mapResp = chatLayerService.map(queryNLReq); - parseContext.setMapInfo(mapResp.getMapInfo()); - } - private ExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) { ExecuteContext executeContext = new ExecuteContext(); BeanMapper.mapper(chatExecuteReq, executeContext); @@ -197,7 +191,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { Integer parseId = chatQueryDataReq.getParseId(); SemanticParseInfo parseInfo = chatManageService.getParseInfo(chatQueryDataReq.getQueryId(), parseId); - parseInfo = mergeParseInfo(parseInfo, chatQueryDataReq); + mergeParseInfo(parseInfo, chatQueryDataReq); DataSetSchema dataSetSchema = semanticLayerService.getDataSetSchema(parseInfo.getDataSetId()); @@ -494,10 +488,9 @@ public class ChatQueryServiceImpl implements ChatQueryService { }); } - private SemanticParseInfo mergeParseInfo(SemanticParseInfo parseInfo, - ChatQueryDataReq queryData) { + private void mergeParseInfo(SemanticParseInfo parseInfo, ChatQueryDataReq queryData) { if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) { - return parseInfo; + return; } if (!CollectionUtils.isEmpty(queryData.getDimensions())) { parseInfo.setDimensions(queryData.getDimensions()); @@ -515,7 +508,6 @@ public class ChatQueryServiceImpl implements ChatQueryService { parseInfo.setDateInfo(queryData.getDateInfo()); } parseInfo.setSqlInfo(new SqlInfo()); - return parseInfo; } private void validFilter(Set filters) { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java index 688784ff9..c75871f83 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java @@ -12,11 +12,11 @@ import java.util.Objects; public class QueryReqConverter { - public static QueryNLReq buildText2SqlQueryReq(ParseContext parseContext) { - return buildText2SqlQueryReq(parseContext, null); + public static QueryNLReq buildQueryNLReq(ParseContext parseContext) { + return buildQueryNLReq(parseContext, null); } - public static QueryNLReq buildText2SqlQueryReq(ParseContext parseContext, ChatContext chatCtx) { + public static QueryNLReq buildQueryNLReq(ParseContext parseContext, ChatContext chatCtx) { QueryNLReq queryNLReq = new QueryNLReq(); BeanMapper.mapper(parseContext, queryNLReq); Agent agent = parseContext.getAgent(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java index 60d3af5ce..c6b2db543 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java @@ -78,8 +78,7 @@ public class QueryTypeParser implements SemanticParser { } private static List filterByTimeFields(List whereFields) { - return whereFields.stream() - .filter(field -> !TimeDimensionEnum.containsTimeDimension(field)) + return whereFields.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field)) .collect(Collectors.toList()); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java index b9edcd3e4..d9c0e942f 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java @@ -89,7 +89,7 @@ public class S2ChatLayerService implements ChatLayerService { public ParseResp parse(QueryNLReq queryNLReq) { ParseResp parseResult = new ParseResp(queryNLReq.getQueryText()); ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq); - chatWorkflowEngine.execute(queryCtx, parseResult); + chatWorkflowEngine.start(queryCtx, parseResult); return parseResult; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java index 1e987ad25..84c497125 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java @@ -36,7 +36,7 @@ public class ChatWorkflowEngine { ComponentFactory.getSemanticCorrectors(); private final List resultProcessors = ComponentFactory.getResultProcessors(); - public void execute(ChatQueryContext queryCtx, ParseResp parseResult) { + public void start(ChatQueryContext queryCtx, ParseResp parseResult) { queryCtx.setChatWorkflowState(ChatWorkflowState.MAPPING); while (queryCtx.getChatWorkflowState() != ChatWorkflowState.FINISHED) { switch (queryCtx.getChatWorkflowState()) { @@ -122,8 +122,8 @@ public class ChatWorkflowEngine { resultProcessors.forEach(processor -> processor.process(parseResult, queryCtx)); } - private void performTranslating(ChatQueryContext chatQueryContext, ParseResp parseResult) { - List semanticParseInfos = chatQueryContext.getCandidateQueries().stream() + private void performTranslating(ChatQueryContext queryCtx, ParseResp parseResult) { + List semanticParseInfos = queryCtx.getCandidateQueries().stream() .map(SemanticQuery::getParseInfo).collect(Collectors.toList()); List errorMsg = new ArrayList<>(); if (StringUtils.isNotBlank(parseResult.getErrorMsg())) { @@ -140,7 +140,7 @@ public class ChatWorkflowEngine { SemanticLayerService queryService = ContextUtils.getBean(SemanticLayerService.class); SemanticTranslateResp explain = - queryService.translate(semanticQueryReq, chatQueryContext.getUser()); + queryService.translate(semanticQueryReq, queryCtx.getUser()); parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL()); if (StringUtils.isNotBlank(explain.getErrMsg())) { errorMsg.add(explain.getErrMsg());