(feature)(chat)Rewrite error message to make it more understandable to the user. #1320

This commit is contained in:
jerryjzhang
2024-07-11 20:33:32 +08:00
parent e0647dd990
commit d6c5702b5a
6 changed files with 97 additions and 67 deletions

View File

@@ -83,7 +83,7 @@ public class PlainTextExecutor implements ChatExecutor {
private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) { private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) {
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class); ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId) List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId)
.stream().filter(p -> p.getState() != ParseResp.ParseState.FAILED).collect(Collectors.toList()); .stream().filter(p -> p.getState() == ParseResp.ParseState.COMPLETED).collect(Collectors.toList());
List<ParseResp> contextualList = contextualParseInfoList.subList(0, List<ParseResp> contextualList = contextualParseInfoList.subList(0,
Math.min(multiNum, contextualParseInfoList.size())); Math.min(multiNum, contextualParseInfoList.size()));

View File

@@ -6,7 +6,6 @@ import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar; import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl; import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
@@ -25,8 +24,6 @@ import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response; import dev.langchain4j.model.output.Response;
import dev.langchain4j.provider.ModelProvider; import dev.langchain4j.provider.ModelProvider;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@@ -34,7 +31,9 @@ import org.springframework.util.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.Optional; import java.util.Optional;
import java.util.Set; import java.util.Set;
@@ -47,36 +46,56 @@ public class NL2SQLParser implements ChatParser {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
private static final String REWRITE_INSTRUCTION = "" private static final String REWRITE_USER_QUESTION_INSTRUCTION = ""
+ "#Role: You are a data product manager experienced in data requirements.\n" + "#Role: You are a data product manager experienced in data requirements."
+ "#Task: Your will be provided with current and history questions asked by a user," + "#Task: Your will be provided with current and history questions asked by a user,"
+ "along with their mapped schema elements(metric, dimension and value)," + "along with their mapped schema elements(metric, dimension and value),"
+ "please try understanding the semantics and rewrite a question.\n" + "please try understanding the semantics and rewrite a question."
+ "#Rules: " + "#Rules: "
+ "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges. " + "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges."
+ "2.ONLY respond with the rewritten question.\n" + "2.ONLY respond with the rewritten question."
+ "#Current Question: %s\n" + "#Current Question: {{current_question}}"
+ "#Current Mapped Schema: %s\n" + "#Current Mapped Schema: {{current_schema}}"
+ "#History Question: %s\n" + "#History Question: {{history_question}}"
+ "#History Mapped Schema: %s\n" + "#History Mapped Schema: {{history_schema}}"
+ "#History SQL: %s\n" + "#History SQL: {{history_sql}}"
+ "#Rewritten Question: "; + "#Rewritten Question: ";
private static final String REWRITE_ERROR_MESSAGE_INSTRUCTION = ""
+ "#Role: You are a data business partner who closely interacts with business people.\n"
+ "#Task: Your will be provided with user input, system output and some examples, "
+ "please respond shortly to teach user how to ask the right question, "
+ "using `Examples` as references."
+ "#Rules: ALWAYS use the same language as the `Input`.\n"
+ "#Input: {{user_question}}\n"
+ "#Output: {{system_message}}\n"
+ "#Examples: {{examples}}\n"
+ "#Response: ";
@Override @Override
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) { public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) { if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) {
return; return;
} }
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
chatParseContext.getAgent().getModelConfig());
processMultiTurn(chatParseContext); processMultiTurn(chatLanguageModel, chatParseContext);
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
addDynamicExemplars(chatParseContext.getAgent().getId(), queryNLReq); addDynamicExemplars(chatParseContext.getAgent().getId(), queryNLReq);
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class); ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryNLReq); ParseResp text2SqlParseResp = chatQueryService.performParsing(queryNLReq);
if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) { if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses()); parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
} else {
parseResp.setErrorMsg(rewriteErrorMessage(chatLanguageModel,
chatParseContext.getQueryText(),
text2SqlParseResp.getErrorMsg(),
queryNLReq.getDynamicExemplars(),
chatParseContext.getAgent().getExamples()));
} }
parseResp.setState(text2SqlParseResp.getState());
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime()); parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
formatParseResult(parseResp); formatParseResult(parseResp);
} }
@@ -136,7 +155,7 @@ public class NL2SQLParser implements ChatParser {
parseInfo.setTextInfo(textBuilder.toString()); parseInfo.setTextInfo(textBuilder.toString());
} }
private void processMultiTurn(ChatParseContext chatParseContext) { private void processMultiTurn(ChatLanguageModel chatLanguageModel, ChatParseContext chatParseContext) {
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig(); MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig();
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE)); Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
@@ -162,30 +181,53 @@ public class NL2SQLParser implements ChatParser {
String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId)); String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId));
String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches()); String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches());
String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectedS2SQL(); String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectedS2SQL();
String rewrittenQuery = rewriteQuery(RewriteContext.builder()
.curtQuestion(currentMapResult.getQueryText()) Map<String, Object> variables = new HashMap<>();
.histQuestion(lastParseResult.getQueryText()) variables.put("current_question", currentMapResult.getQueryText());
.curtSchema(curtMapStr) variables.put("current_schema", curtMapStr);
.histSchema(histMapStr) variables.put("history_question", lastParseResult.getQueryText());
.histSQL(histSQL) variables.put("history_schema", histMapStr);
.modelConfig(queryNLReq.getModelConfig()) variables.put("history_sql", histSQL);
.build());
Prompt prompt = PromptTemplate.from(REWRITE_USER_QUESTION_INSTRUCTION).apply(variables);
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
String rewrittenQuery = response.content().text();
keyPipelineLog.info("NL2SQLParser modelResp:{}", rewrittenQuery);
chatParseContext.setQueryText(rewrittenQuery); chatParseContext.setQueryText(rewrittenQuery);
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery); lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
} }
private String rewriteQuery(RewriteContext context) { private String rewriteErrorMessage(ChatLanguageModel chatLanguageModel, String userQuestion,
String promptStr = String.format(REWRITE_INSTRUCTION, context.getCurtQuestion(), context.getCurtSchema(), String errMsg, List<SqlExemplar> similarExemplars,
context.getHistQuestion(), context.getHistSchema(), context.getHistSQL()); List<String> agentExamples) {
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); Map<String, Object> variables = new HashMap<>();
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", promptStr); variables.put("user_question", userQuestion);
variables.put("system_message", errMsg);
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(context.getModelConfig()); StringBuilder exampleStr = new StringBuilder();
if (similarExemplars.size() > 0) {
similarExemplars.stream().forEach(e ->
exampleStr.append(String.format("<Question:{%s},Schema:{%s}> ",
e.getQuestion(), e.getDbSchema()))
);
} else {
agentExamples.stream().forEach(e ->
exampleStr.append(String.format("<Question:{%s}> ",
e)));
}
variables.put("examples", exampleStr);
Prompt prompt = PromptTemplate.from(REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(variables);
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage()); Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
String result = response.content().text(); String result = response.content().text();
keyPipelineLog.info("NL2SQLParser modelResp:{}", result); keyPipelineLog.info("NL2SQLParser modelResp:{}", result);
return response.content().text(); return response.content().text();
} }
@@ -217,7 +259,7 @@ public class NL2SQLParser implements ChatParser {
private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) { private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) {
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class); ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId) List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId)
.stream().filter(p -> p.getState() != ParseResp.ParseState.FAILED).collect(Collectors.toList()); .stream().filter(p -> p.getState() == ParseResp.ParseState.COMPLETED).collect(Collectors.toList());
List<ParseResp> contextualList = contextualParseInfoList.subList(0, List<ParseResp> contextualList = contextualParseInfoList.subList(0,
Math.min(multiNum, contextualParseInfoList.size())); Math.min(multiNum, contextualParseInfoList.size()));
@@ -234,16 +276,4 @@ public class NL2SQLParser implements ChatParser {
queryNLReq.getDynamicExemplars().addAll(exemplars); queryNLReq.getDynamicExemplars().addAll(exemplars);
} }
@Builder
@Data
public static class RewriteContext {
private String curtQuestion;
private String histQuestion;
private String curtSchema;
private String histSchema;
private String histSQL;
private ModelConfig modelConfig;
}
} }

View File

@@ -23,10 +23,6 @@ public class SchemaMapInfo {
return dataSetElementMatches; return dataSetElementMatches;
} }
public void setDataSetElementMatches(Map<Long, List<SchemaElementMatch>> dataSetElementMatches) {
this.dataSetElementMatches = dataSetElementMatches;
}
public void setMatchedElements(Long dataSet, List<SchemaElementMatch> elementMatches) { public void setMatchedElements(Long dataSet, List<SchemaElementMatch> elementMatches) {
dataSetElementMatches.put(dataSet, elementMatches); dataSetElementMatches.put(dataSet, elementMatches);
} }

View File

@@ -3,7 +3,6 @@ package com.tencent.supersonic.headless.api.pojo.response;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import lombok.Data; import lombok.Data;
import org.apache.commons.collections.CollectionUtils;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;
@@ -14,7 +13,8 @@ public class ParseResp {
private Integer chatId; private Integer chatId;
private String queryText; private String queryText;
private Long queryId; private Long queryId;
private ParseState state; private ParseState state = ParseState.PENDING;
private String errorMsg;
private List<SemanticParseInfo> selectedParses = Lists.newArrayList(); private List<SemanticParseInfo> selectedParses = Lists.newArrayList();
private ParseTimeCostResp parseTimeCost = new ParseTimeCostResp(); private ParseTimeCostResp parseTimeCost = new ParseTimeCostResp();
@@ -38,15 +38,6 @@ public class ParseResp {
return selectedParses; return selectedParses;
} }
public ParseState getState() {
if (CollectionUtils.isNotEmpty(selectedParses)) {
this.state = ParseResp.ParseState.COMPLETED;
} else {
this.state = ParseState.FAILED;
}
return this.state;
}
private void generateParseInfoId(List<SemanticParseInfo> selectedParses) { private void generateParseInfoId(List<SemanticParseInfo> selectedParses) {
for (int i = 0; i < selectedParses.size(); i++) { for (int i = 0; i < selectedParses.size(); i++) {
SemanticParseInfo parseInfo = selectedParses.get(i); SemanticParseInfo parseInfo = selectedParses.get(i);

View File

@@ -45,11 +45,23 @@ public class ChatWorkflowEngine {
switch (queryCtx.getChatWorkflowState()) { switch (queryCtx.getChatWorkflowState()) {
case MAPPING: case MAPPING:
performMapping(queryCtx); performMapping(queryCtx);
queryCtx.setChatWorkflowState(ChatWorkflowState.PARSING); if (queryCtx.getMapInfo().getMatchedDataSetInfos().size() == 0) {
parseResult.setState(ParseResp.ParseState.FAILED);
parseResult.setErrorMsg("No semantic entities can be mapped against user question.");
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
} else {
queryCtx.setChatWorkflowState(ChatWorkflowState.PARSING);
}
break; break;
case PARSING: case PARSING:
performParsing(queryCtx, chatCtx); performParsing(queryCtx, chatCtx);
queryCtx.setChatWorkflowState(ChatWorkflowState.CORRECTING); if (queryCtx.getCandidateQueries().size() == 0) {
parseResult.setState(ParseResp.ParseState.FAILED);
parseResult.setErrorMsg("No semantic queries can be parsed out.");
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
} else {
queryCtx.setChatWorkflowState(ChatWorkflowState.CORRECTING);
}
break; break;
case CORRECTING: case CORRECTING:
performCorrecting(queryCtx); performCorrecting(queryCtx);
@@ -64,27 +76,30 @@ public class ChatWorkflowEngine {
case PROCESSING: case PROCESSING:
default: default:
performProcessing(queryCtx, chatCtx, parseResult); performProcessing(queryCtx, chatCtx, parseResult);
if (parseResult.getState().equals(ParseResp.ParseState.PENDING)) {
parseResult.setState(ParseResp.ParseState.COMPLETED);
}
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED); queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
break; break;
} }
} }
} }
public void performMapping(ChatQueryContext queryCtx) { private void performMapping(ChatQueryContext queryCtx) {
if (Objects.isNull(queryCtx.getMapInfo()) if (Objects.isNull(queryCtx.getMapInfo())
|| MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches())) { || MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches())) {
schemaMappers.forEach(mapper -> mapper.map(queryCtx)); schemaMappers.forEach(mapper -> mapper.map(queryCtx));
} }
} }
public void performParsing(ChatQueryContext queryCtx, ChatContext chatCtx) { private void performParsing(ChatQueryContext queryCtx, ChatContext chatCtx) {
semanticParsers.forEach(parser -> { semanticParsers.forEach(parser -> {
parser.parse(queryCtx, chatCtx); parser.parse(queryCtx, chatCtx);
log.debug("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx)); log.debug("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
}); });
} }
public void performCorrecting(ChatQueryContext queryCtx) { private void performCorrecting(ChatQueryContext queryCtx) {
List<SemanticQuery> candidateQueries = queryCtx.getCandidateQueries(); List<SemanticQuery> candidateQueries = queryCtx.getCandidateQueries();
if (CollectionUtils.isNotEmpty(candidateQueries)) { if (CollectionUtils.isNotEmpty(candidateQueries)) {
for (SemanticQuery semanticQuery : candidateQueries) { for (SemanticQuery semanticQuery : candidateQueries) {
@@ -101,7 +116,7 @@ public class ChatWorkflowEngine {
} }
} }
public void performProcessing(ChatQueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) { private void performProcessing(ChatQueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) {
resultProcessors.forEach(processor -> { resultProcessors.forEach(processor -> {
processor.process(parseResult, queryCtx, chatCtx); processor.process(parseResult, queryCtx, chatCtx);
}); });

View File

@@ -25,8 +25,6 @@ import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
public class MetricTest extends BaseTest { public class MetricTest extends BaseTest {
private int chatId = 10;
@Test @Test
public void testMetricFilter() throws Exception { public void testMetricFilter() throws Exception {
QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.metricAgentId); QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.metricAgentId);