diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java index 6980425dd..a9fb0eba6 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java @@ -19,4 +19,6 @@ public class ChatParseReq { private QueryFilters queryFilters; private boolean saveAnswer = true; private boolean disableLLM = false; + private Long queryId; + private Integer parseId; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 609782830..fc4bac952 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.chat.server.parser; -import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.api.pojo.response.QueryResp; import com.tencent.supersonic.chat.server.pojo.ChatContext; import com.tencent.supersonic.chat.server.pojo.ParseContext; @@ -17,10 +16,12 @@ import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.QueryState; +import com.tencent.supersonic.headless.chat.parser.ParserConfig; import com.tencent.supersonic.headless.server.facade.service.ChatLayerService; import com.tencent.supersonic.headless.server.utils.ModelConfigHelper; import dev.langchain4j.data.message.AiMessage; @@ -74,37 +75,48 @@ public class NL2SQLParser implements ChatQueryParser { if (!parseContext.enableNL2SQL() || Objects.isNull(parseContext.getAgent())) { return; } - QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext); - if (Objects.isNull(queryNLReq)) { + if (parseContext.enableFeedback()) { + processFeedback(parseContext); return; } - ParseResp parseResp = parseContext.getResponse(); - ChatParseReq parseReq = parseContext.getRequest(); - - if (!parseContext.getRequest().isDisableLLM() && queryNLReq.getText2SQLType().enableLLM()) { - processMultiTurn(parseContext); - addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq); - parseResp.setUsedExemplars(queryNLReq.getDynamicExemplars()); - } + QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext); ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class); - ChatContext chatCtx = chatContextService.getOrCreateContext(parseReq.getChatId()); - if (chatCtx != null) { + ChatContext chatCtx = + chatContextService.getOrCreateContext(parseContext.getRequest().getChatId()); + if (chatCtx != null && Objects.isNull(queryNLReq.getContextParseInfo())) { queryNLReq.setContextParseInfo(chatCtx.getParseInfo()); } - - ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class); - ParseResp text2SqlParseResp = chatLayerService.parse(queryNLReq); - if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) { - parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses()); + if (parseContext.enableLLM()) { + rewriteMultiTurn(parseContext, queryNLReq); + addDynamicExemplars(parseContext, queryNLReq); } - parseResp.setErrorMsg(text2SqlParseResp.getErrorMsg()); - parseResp.setState(text2SqlParseResp.getState()); - parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime()); - parseResp.setErrorMsg(text2SqlParseResp.getErrorMsg()); + + ParseResp parseResp = parseContext.getResponse(); + doParse(queryNLReq, parseResp); } - private void processMultiTurn(ParseContext parseContext) { + private void processFeedback(ParseContext parseContext) { + QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext); + ParseResp parseResp = parseContext.getResponse(); + for (MapModeEnum mode : MapModeEnum.values()) { + queryNLReq.setMapModeEnum(mode); + doParse(queryNLReq, parseResp); + } + } + + private void doParse(QueryNLReq req, ParseResp resp) { + ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class); + ParseResp text2SqlParseResp = chatLayerService.parse(req); + if (text2SqlParseResp.getState().equals(ParseResp.ParseState.COMPLETED)) { + resp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses()); + } + resp.setState(text2SqlParseResp.getState()); + resp.setParseTimeCost(text2SqlParseResp.getParseTimeCost()); + resp.setErrorMsg(text2SqlParseResp.getErrorMsg()); + } + + private void rewriteMultiTurn(ParseContext parseContext, QueryNLReq queryNLReq) { ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_MULTI_TURN); if (Objects.isNull(chatApp) || !chatApp.isEnable()) { return; @@ -112,7 +124,6 @@ public class NL2SQLParser implements ChatQueryParser { // derive mapping result of current question and parsing result of last question. ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class); - QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext); MapResp currentMapResult = chatLayerService.map(queryNLReq); List historyQueries = @@ -143,6 +154,7 @@ public class NL2SQLParser implements ChatQueryParser { String rewrittenQuery = response.content().text(); keyPipelineLog.info("QueryRewrite modelReq:\n{} \nmodelResp:\n{}", prompt.text(), response); parseContext.getRequest().setQueryText(rewrittenQuery); + queryNLReq.setQueryText(rewrittenQuery); log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", lastQuery.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery); } @@ -185,15 +197,17 @@ public class NL2SQLParser implements ChatQueryParser { return contextualList; } - private void addDynamicExemplars(Integer agentId, QueryNLReq queryNLReq) { + private void addDynamicExemplars(ParseContext parseContext, QueryNLReq queryNLReq) { ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class); EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class); - String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId); + String memoryCollectionName = + embeddingConfig.getMemoryCollectionName(parseContext.getAgent().getId()); ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); int exemplarRecallNumber = Integer.parseInt(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER)); List exemplars = exemplarManager.recallExemplars(memoryCollectionName, queryNLReq.getQueryText(), exemplarRecallNumber); queryNLReq.getDynamicExemplars().addAll(exemplars); + parseContext.getResponse().setUsedExemplars(exemplars); } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/ParserConfig.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/ParserConfig.java deleted file mode 100644 index 9df5123d6..000000000 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/ParserConfig.java +++ /dev/null @@ -1,15 +0,0 @@ -package com.tencent.supersonic.chat.server.parser; - -import com.tencent.supersonic.common.config.ParameterConfig; -import com.tencent.supersonic.common.pojo.Parameter; -import lombok.extern.slf4j.Slf4j; -import org.springframework.stereotype.Service; - -@Service("ChatQueryParserConfig") -@Slf4j -public class ParserConfig extends ParameterConfig { - - public static final Parameter PARSER_MULTI_TURN_ENABLE = - new Parameter("s2.parser.multi-turn.enable", "false", "是否开启多轮对话", "开启多轮对话将消耗更多token", - "bool", "语义解析配置"); -} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java index 1eac60017..e05931856 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java @@ -2,14 +2,18 @@ package com.tencent.supersonic.chat.server.pojo; import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.server.agent.Agent; +import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import lombok.Data; +import java.util.Objects; + @Data public class ParseContext { private ChatParseReq request; private ParseResp response; private Agent agent; + private SemanticParseInfo selectedParseInfo; public ParseContext(ChatParseReq request) { this.request = request; @@ -17,9 +21,14 @@ public class ParseContext { } public boolean enableNL2SQL() { - if (agent == null) { - return false; - } return agent.containsDatasetTool(); } + + public boolean enableFeedback() { + return agent.enableFeedback() && Objects.isNull(request.getParseId()); + } + + public boolean enableLLM() { + return !(enableFeedback() || request.isDisableLLM()); + } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java index 8c786470f..c6f485085 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java @@ -168,6 +168,12 @@ public class ChatQueryServiceImpl implements ChatQueryService { ParseContext parseContext = new ParseContext(chatParseReq); Agent agent = agentService.getAgent(chatParseReq.getAgentId()); parseContext.setAgent(agent); + if (Objects.nonNull(chatParseReq.getQueryId()) + && Objects.nonNull(chatParseReq.getParseId())) { + SemanticParseInfo parseInfo = chatManageService.getParseInfo(chatParseReq.getQueryId(), + chatParseReq.getParseId()); + parseContext.setSelectedParseInfo(parseInfo); + } return parseContext; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java index dfa91a667..d743d5f6b 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java @@ -8,16 +8,13 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; public class QueryReqConverter { public static QueryNLReq buildQueryNLReq(ParseContext parseContext) { - if (parseContext.getAgent() == null) { - return null; - } - QueryNLReq queryNLReq = new QueryNLReq(); BeanMapper.mapper(parseContext.getRequest(), queryNLReq); - queryNLReq.setText2SQLType(parseContext.getRequest().isDisableLLM() ? Text2SQLType.ONLY_RULE - : Text2SQLType.RULE_AND_LLM); + queryNLReq.setText2SQLType( + parseContext.enableLLM() ? Text2SQLType.RULE_AND_LLM : Text2SQLType.ONLY_RULE); queryNLReq.setDataSetIds(parseContext.getAgent().getDataSetIds()); queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig()); + queryNLReq.setSelectedParseInfo(parseContext.getSelectedParseInfo()); return queryNLReq; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java index 96197527a..a6661af67 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java @@ -15,6 +15,10 @@ public class SchemaMapInfo { private final Map> dataSetElementMatches = new HashMap<>(); + public boolean isEmpty() { + return dataSetElementMatches.keySet().isEmpty(); + } + public Set getMatchedDataSetInfos() { return dataSetElementMatches.keySet(); } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java index 11b7826d2..a0bbe1e7f 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryNLReq.java @@ -28,6 +28,7 @@ public class QueryNLReq extends SemanticQueryReq { private Map chatAppConfig; private List dynamicExemplars = Lists.newArrayList(); private SemanticParseInfo contextParseInfo; + private SemanticParseInfo selectedParseInfo; @Override public String toCustomizedString() { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java index a5601f7fa..a37e4c1cd 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java @@ -4,6 +4,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; +import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; @@ -15,6 +16,7 @@ import java.util.ArrayList; import java.util.Comparator; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.stream.Collectors; @Data @@ -35,6 +37,10 @@ public class ChatQueryContext { public ChatQueryContext(QueryNLReq request) { this.request = request; + SemanticParseInfo parseInfo = request.getSelectedParseInfo(); + if (Objects.nonNull(parseInfo) && Objects.nonNull(parseInfo.getDataSetId())) { + mapInfo.setMatchedElements(parseInfo.getDataSetId(), parseInfo.getElementMatches()); + } } public List getCandidateQueries() { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java index 79c604789..22305b8d9 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java @@ -15,6 +15,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.SqlEvaluation; import com.tencent.supersonic.headless.api.pojo.SqlInfo; +import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState; import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq; @@ -87,7 +88,11 @@ public class S2ChatLayerService implements ChatLayerService { public ParseResp parse(QueryNLReq queryNLReq) { ParseResp parseResult = new ParseResp(queryNLReq.getQueryText()); ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq); - chatWorkflowEngine.start(queryCtx, parseResult); + if (queryCtx.getMapInfo().isEmpty()) { + chatWorkflowEngine.start(ChatWorkflowState.MAPPING, queryCtx, parseResult); + } else { + chatWorkflowEngine.start(ChatWorkflowState.PARSING, queryCtx, parseResult); + } return parseResult; } @@ -113,6 +118,7 @@ public class S2ChatLayerService implements ChatLayerService { Map> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds(); queryCtx.setSemanticSchema(semanticSchema); queryCtx.setModelIdToDataSetIds(modelIdToDataSetIds); + return queryCtx; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java index a03a4b6d5..8308d9655 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java @@ -36,13 +36,14 @@ public class ChatWorkflowEngine { ComponentFactory.getSemanticCorrectors(); private final List resultProcessors = ComponentFactory.getResultProcessors(); - public void start(ChatQueryContext queryCtx, ParseResp parseResult) { - queryCtx.setChatWorkflowState(ChatWorkflowState.MAPPING); + public void start(ChatWorkflowState initialState, ChatQueryContext queryCtx, + ParseResp parseResult) { + queryCtx.setChatWorkflowState(initialState); while (queryCtx.getChatWorkflowState() != ChatWorkflowState.FINISHED) { switch (queryCtx.getChatWorkflowState()) { case MAPPING: performMapping(queryCtx); - if (queryCtx.getMapInfo().getMatchedDataSetInfos().isEmpty()) { + if (queryCtx.getMapInfo().isEmpty()) { parseResult.setState(ParseResp.ParseState.FAILED); parseResult.setErrorMsg( "No semantic entities can be mapped against user question.");