From d6c5702b5aa8eba410d2aad22231a3f81b27e481 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Thu, 11 Jul 2024 20:33:32 +0800 Subject: [PATCH] (feature)(chat)Rewrite error message to make it more understandable to the user. #1320 --- .../server/executor/PlainTextExecutor.java | 2 +- .../chat/server/parser/NL2SQLParser.java | 116 +++++++++++------- .../headless/api/pojo/SchemaMapInfo.java | 4 - .../headless/api/pojo/response/ParseResp.java | 13 +- .../server/utils/ChatWorkflowEngine.java | 27 +++- .../tencent/supersonic/chat/MetricTest.java | 2 - 6 files changed, 97 insertions(+), 67 deletions(-) diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java index be849f7c6..f0bd97508 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java @@ -83,7 +83,7 @@ public class PlainTextExecutor implements ChatExecutor { private List getHistoryParseResult(int chatId, int multiNum) { ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class); List contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId) - .stream().filter(p -> p.getState() != ParseResp.ParseState.FAILED).collect(Collectors.toList()); + .stream().filter(p -> p.getState() == ParseResp.ParseState.COMPLETED).collect(Collectors.toList()); List contextualList = contextualParseInfoList.subList(0, Math.min(multiNum, contextualParseInfoList.size())); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 5ac876438..e4b3937f1 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -6,7 +6,6 @@ import com.tencent.supersonic.chat.server.plugin.PluginQueryManager; import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.common.config.EmbeddingConfig; -import com.tencent.supersonic.common.config.ModelConfig; import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl; import com.tencent.supersonic.common.util.ContextUtils; @@ -25,8 +24,6 @@ import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.model.output.Response; import dev.langchain4j.provider.ModelProvider; -import lombok.Builder; -import lombok.Data; import lombok.extern.slf4j.Slf4j; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -34,7 +31,9 @@ import org.springframework.util.CollectionUtils; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; @@ -47,36 +46,56 @@ public class NL2SQLParser implements ChatParser { private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); - private static final String REWRITE_INSTRUCTION = "" - + "#Role: You are a data product manager experienced in data requirements.\n" + private static final String REWRITE_USER_QUESTION_INSTRUCTION = "" + + "#Role: You are a data product manager experienced in data requirements." + "#Task: Your will be provided with current and history questions asked by a user," + "along with their mapped schema elements(metric, dimension and value)," - + "please try understanding the semantics and rewrite a question.\n" + + "please try understanding the semantics and rewrite a question." + "#Rules: " - + "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges. " - + "2.ONLY respond with the rewritten question.\n" - + "#Current Question: %s\n" - + "#Current Mapped Schema: %s\n" - + "#History Question: %s\n" - + "#History Mapped Schema: %s\n" - + "#History SQL: %s\n" + + "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges." + + "2.ONLY respond with the rewritten question." + + "#Current Question: {{current_question}}" + + "#Current Mapped Schema: {{current_schema}}" + + "#History Question: {{history_question}}" + + "#History Mapped Schema: {{history_schema}}" + + "#History SQL: {{history_sql}}" + "#Rewritten Question: "; + private static final String REWRITE_ERROR_MESSAGE_INSTRUCTION = "" + + "#Role: You are a data business partner who closely interacts with business people.\n" + + "#Task: Your will be provided with user input, system output and some examples, " + + "please respond shortly to teach user how to ask the right question, " + + "using `Examples` as references." + + "#Rules: ALWAYS use the same language as the `Input`.\n" + + "#Input: {{user_question}}\n" + + "#Output: {{system_message}}\n" + + "#Examples: {{examples}}\n" + + "#Response: "; + @Override public void parse(ChatParseContext chatParseContext, ParseResp parseResp) { if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) { return; } + ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel( + chatParseContext.getAgent().getModelConfig()); - processMultiTurn(chatParseContext); + processMultiTurn(chatLanguageModel, chatParseContext); QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); addDynamicExemplars(chatParseContext.getAgent().getId(), queryNLReq); ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class); ParseResp text2SqlParseResp = chatQueryService.performParsing(queryNLReq); - if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) { + if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) { parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses()); + } else { + parseResp.setErrorMsg(rewriteErrorMessage(chatLanguageModel, + chatParseContext.getQueryText(), + text2SqlParseResp.getErrorMsg(), + queryNLReq.getDynamicExemplars(), + chatParseContext.getAgent().getExamples())); } + parseResp.setState(text2SqlParseResp.getState()); parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime()); formatParseResult(parseResp); } @@ -136,7 +155,7 @@ public class NL2SQLParser implements ChatParser { parseInfo.setTextInfo(textBuilder.toString()); } - private void processMultiTurn(ChatParseContext chatParseContext) { + private void processMultiTurn(ChatLanguageModel chatLanguageModel, ChatParseContext chatParseContext) { ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig(); Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE)); @@ -162,30 +181,53 @@ public class NL2SQLParser implements ChatParser { String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId)); String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches()); String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectedS2SQL(); - String rewrittenQuery = rewriteQuery(RewriteContext.builder() - .curtQuestion(currentMapResult.getQueryText()) - .histQuestion(lastParseResult.getQueryText()) - .curtSchema(curtMapStr) - .histSchema(histMapStr) - .histSQL(histSQL) - .modelConfig(queryNLReq.getModelConfig()) - .build()); + + Map variables = new HashMap<>(); + variables.put("current_question", currentMapResult.getQueryText()); + variables.put("current_schema", curtMapStr); + variables.put("history_question", lastParseResult.getQueryText()); + variables.put("history_schema", histMapStr); + variables.put("history_sql", histSQL); + + Prompt prompt = PromptTemplate.from(REWRITE_USER_QUESTION_INSTRUCTION).apply(variables); + keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text()); + + Response response = chatLanguageModel.generate(prompt.toUserMessage()); + String rewrittenQuery = response.content().text(); + keyPipelineLog.info("NL2SQLParser modelResp:{}", rewrittenQuery); + chatParseContext.setQueryText(rewrittenQuery); log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery); } - private String rewriteQuery(RewriteContext context) { - String promptStr = String.format(REWRITE_INSTRUCTION, context.getCurtQuestion(), context.getCurtSchema(), - context.getHistQuestion(), context.getHistSchema(), context.getHistSQL()); - Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); - keyPipelineLog.info("NL2SQLParser reqPrompt:{}", promptStr); + private String rewriteErrorMessage(ChatLanguageModel chatLanguageModel, String userQuestion, + String errMsg, List similarExemplars, + List agentExamples) { + Map variables = new HashMap<>(); + variables.put("user_question", userQuestion); + variables.put("system_message", errMsg); - ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(context.getModelConfig()); + StringBuilder exampleStr = new StringBuilder(); + if (similarExemplars.size() > 0) { + similarExemplars.stream().forEach(e -> + exampleStr.append(String.format(" ", + e.getQuestion(), e.getDbSchema())) + ); + } else { + agentExamples.stream().forEach(e -> + exampleStr.append(String.format(" ", + e))); + } + variables.put("examples", exampleStr); + + Prompt prompt = PromptTemplate.from(REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(variables); + keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text()); Response response = chatLanguageModel.generate(prompt.toUserMessage()); String result = response.content().text(); keyPipelineLog.info("NL2SQLParser modelResp:{}", result); + return response.content().text(); } @@ -217,7 +259,7 @@ public class NL2SQLParser implements ChatParser { private List getHistoryParseResult(int chatId, int multiNum) { ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class); List contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId) - .stream().filter(p -> p.getState() != ParseResp.ParseState.FAILED).collect(Collectors.toList()); + .stream().filter(p -> p.getState() == ParseResp.ParseState.COMPLETED).collect(Collectors.toList()); List contextualList = contextualParseInfoList.subList(0, Math.min(multiNum, contextualParseInfoList.size())); @@ -234,16 +276,4 @@ public class NL2SQLParser implements ChatParser { queryNLReq.getDynamicExemplars().addAll(exemplars); } - @Builder - @Data - public static class RewriteContext { - - private String curtQuestion; - private String histQuestion; - private String curtSchema; - private String histSchema; - private String histSQL; - private ModelConfig modelConfig; - } - } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java index 483b2a428..95d28a28f 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaMapInfo.java @@ -23,10 +23,6 @@ public class SchemaMapInfo { return dataSetElementMatches; } - public void setDataSetElementMatches(Map> dataSetElementMatches) { - this.dataSetElementMatches = dataSetElementMatches; - } - public void setMatchedElements(Long dataSet, List elementMatches) { dataSetElementMatches.put(dataSet, elementMatches); } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/ParseResp.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/ParseResp.java index 02d3ff62b..63acea7a7 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/ParseResp.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/ParseResp.java @@ -3,7 +3,6 @@ package com.tencent.supersonic.headless.api.pojo.response; import com.google.common.collect.Lists; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import lombok.Data; -import org.apache.commons.collections.CollectionUtils; import java.util.Comparator; import java.util.List; @@ -14,7 +13,8 @@ public class ParseResp { private Integer chatId; private String queryText; private Long queryId; - private ParseState state; + private ParseState state = ParseState.PENDING; + private String errorMsg; private List selectedParses = Lists.newArrayList(); private ParseTimeCostResp parseTimeCost = new ParseTimeCostResp(); @@ -38,15 +38,6 @@ public class ParseResp { return selectedParses; } - public ParseState getState() { - if (CollectionUtils.isNotEmpty(selectedParses)) { - this.state = ParseResp.ParseState.COMPLETED; - } else { - this.state = ParseState.FAILED; - } - return this.state; - } - private void generateParseInfoId(List selectedParses) { for (int i = 0; i < selectedParses.size(); i++) { SemanticParseInfo parseInfo = selectedParses.get(i); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java index 5914cedd4..b8829aad4 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java @@ -45,11 +45,23 @@ public class ChatWorkflowEngine { switch (queryCtx.getChatWorkflowState()) { case MAPPING: performMapping(queryCtx); - queryCtx.setChatWorkflowState(ChatWorkflowState.PARSING); + if (queryCtx.getMapInfo().getMatchedDataSetInfos().size() == 0) { + parseResult.setState(ParseResp.ParseState.FAILED); + parseResult.setErrorMsg("No semantic entities can be mapped against user question."); + queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED); + } else { + queryCtx.setChatWorkflowState(ChatWorkflowState.PARSING); + } break; case PARSING: performParsing(queryCtx, chatCtx); - queryCtx.setChatWorkflowState(ChatWorkflowState.CORRECTING); + if (queryCtx.getCandidateQueries().size() == 0) { + parseResult.setState(ParseResp.ParseState.FAILED); + parseResult.setErrorMsg("No semantic queries can be parsed out."); + queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED); + } else { + queryCtx.setChatWorkflowState(ChatWorkflowState.CORRECTING); + } break; case CORRECTING: performCorrecting(queryCtx); @@ -64,27 +76,30 @@ public class ChatWorkflowEngine { case PROCESSING: default: performProcessing(queryCtx, chatCtx, parseResult); + if (parseResult.getState().equals(ParseResp.ParseState.PENDING)) { + parseResult.setState(ParseResp.ParseState.COMPLETED); + } queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED); break; } } } - public void performMapping(ChatQueryContext queryCtx) { + private void performMapping(ChatQueryContext queryCtx) { if (Objects.isNull(queryCtx.getMapInfo()) || MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches())) { schemaMappers.forEach(mapper -> mapper.map(queryCtx)); } } - public void performParsing(ChatQueryContext queryCtx, ChatContext chatCtx) { + private void performParsing(ChatQueryContext queryCtx, ChatContext chatCtx) { semanticParsers.forEach(parser -> { parser.parse(queryCtx, chatCtx); log.debug("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx)); }); } - public void performCorrecting(ChatQueryContext queryCtx) { + private void performCorrecting(ChatQueryContext queryCtx) { List candidateQueries = queryCtx.getCandidateQueries(); if (CollectionUtils.isNotEmpty(candidateQueries)) { for (SemanticQuery semanticQuery : candidateQueries) { @@ -101,7 +116,7 @@ public class ChatWorkflowEngine { } } - public void performProcessing(ChatQueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) { + private void performProcessing(ChatQueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) { resultProcessors.forEach(processor -> { processor.process(parseResult, queryCtx, chatCtx); }); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java index a6d714620..539dbe150 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java @@ -25,8 +25,6 @@ import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM; public class MetricTest extends BaseTest { - private int chatId = 10; - @Test public void testMetricFilter() throws Exception { QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.metricAgentId);