(feature)(chat)Rewrite error message to make it more understandable to the user. #1320

This commit is contained in:
jerryjzhang
2024-07-11 20:33:32 +08:00
parent e0647dd990
commit d6c5702b5a
6 changed files with 97 additions and 67 deletions

View File

@@ -83,7 +83,7 @@ public class PlainTextExecutor implements ChatExecutor {
private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) {
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId)
.stream().filter(p -> p.getState() != ParseResp.ParseState.FAILED).collect(Collectors.toList());
.stream().filter(p -> p.getState() == ParseResp.ParseState.COMPLETED).collect(Collectors.toList());
List<ParseResp> contextualList = contextualParseInfoList.subList(0,
Math.min(multiNum, contextualParseInfoList.size()));

View File

@@ -6,7 +6,6 @@ import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.config.ModelConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -25,8 +24,6 @@ import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.provider.ModelProvider;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -34,7 +31,9 @@ import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
@@ -47,36 +46,56 @@ public class NL2SQLParser implements ChatParser {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
private static final String REWRITE_INSTRUCTION = ""
+ "#Role: You are a data product manager experienced in data requirements.\n"
private static final String REWRITE_USER_QUESTION_INSTRUCTION = ""
+ "#Role: You are a data product manager experienced in data requirements."
+ "#Task: Your will be provided with current and history questions asked by a user,"
+ "along with their mapped schema elements(metric, dimension and value),"
+ "please try understanding the semantics and rewrite a question.\n"
+ "please try understanding the semantics and rewrite a question."
+ "#Rules: "
+ "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges. "
+ "2.ONLY respond with the rewritten question.\n"
+ "#Current Question: %s\n"
+ "#Current Mapped Schema: %s\n"
+ "#History Question: %s\n"
+ "#History Mapped Schema: %s\n"
+ "#History SQL: %s\n"
+ "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges."
+ "2.ONLY respond with the rewritten question."
+ "#Current Question: {{current_question}}"
+ "#Current Mapped Schema: {{current_schema}}"
+ "#History Question: {{history_question}}"
+ "#History Mapped Schema: {{history_schema}}"
+ "#History SQL: {{history_sql}}"
+ "#Rewritten Question: ";
private static final String REWRITE_ERROR_MESSAGE_INSTRUCTION = ""
+ "#Role: You are a data business partner who closely interacts with business people.\n"
+ "#Task: Your will be provided with user input, system output and some examples, "
+ "please respond shortly to teach user how to ask the right question, "
+ "using `Examples` as references."
+ "#Rules: ALWAYS use the same language as the `Input`.\n"
+ "#Input: {{user_question}}\n"
+ "#Output: {{system_message}}\n"
+ "#Examples: {{examples}}\n"
+ "#Response: ";
@Override
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) {
return;
}
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
chatParseContext.getAgent().getModelConfig());
processMultiTurn(chatParseContext);
processMultiTurn(chatLanguageModel, chatParseContext);
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
addDynamicExemplars(chatParseContext.getAgent().getId(), queryNLReq);
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryNLReq);
if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) {
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
} else {
parseResp.setErrorMsg(rewriteErrorMessage(chatLanguageModel,
chatParseContext.getQueryText(),
text2SqlParseResp.getErrorMsg(),
queryNLReq.getDynamicExemplars(),
chatParseContext.getAgent().getExamples()));
}
parseResp.setState(text2SqlParseResp.getState());
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
formatParseResult(parseResp);
}
@@ -136,7 +155,7 @@ public class NL2SQLParser implements ChatParser {
parseInfo.setTextInfo(textBuilder.toString());
}
private void processMultiTurn(ChatParseContext chatParseContext) {
private void processMultiTurn(ChatLanguageModel chatLanguageModel, ChatParseContext chatParseContext) {
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig();
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
@@ -162,30 +181,53 @@ public class NL2SQLParser implements ChatParser {
String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId));
String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches());
String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectedS2SQL();
String rewrittenQuery = rewriteQuery(RewriteContext.builder()
.curtQuestion(currentMapResult.getQueryText())
.histQuestion(lastParseResult.getQueryText())
.curtSchema(curtMapStr)
.histSchema(histMapStr)
.histSQL(histSQL)
.modelConfig(queryNLReq.getModelConfig())
.build());
Map<String, Object> variables = new HashMap<>();
variables.put("current_question", currentMapResult.getQueryText());
variables.put("current_schema", curtMapStr);
variables.put("history_question", lastParseResult.getQueryText());
variables.put("history_schema", histMapStr);
variables.put("history_sql", histSQL);
Prompt prompt = PromptTemplate.from(REWRITE_USER_QUESTION_INSTRUCTION).apply(variables);
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
String rewrittenQuery = response.content().text();
keyPipelineLog.info("NL2SQLParser modelResp:{}", rewrittenQuery);
chatParseContext.setQueryText(rewrittenQuery);
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
}
private String rewriteQuery(RewriteContext context) {
String promptStr = String.format(REWRITE_INSTRUCTION, context.getCurtQuestion(), context.getCurtSchema(),
context.getHistQuestion(), context.getHistSchema(), context.getHistSQL());
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", promptStr);
private String rewriteErrorMessage(ChatLanguageModel chatLanguageModel, String userQuestion,
String errMsg, List<SqlExemplar> similarExemplars,
List<String> agentExamples) {
Map<String, Object> variables = new HashMap<>();
variables.put("user_question", userQuestion);
variables.put("system_message", errMsg);
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(context.getModelConfig());
StringBuilder exampleStr = new StringBuilder();
if (similarExemplars.size() > 0) {
similarExemplars.stream().forEach(e ->
exampleStr.append(String.format("<Question:{%s},Schema:{%s}> ",
e.getQuestion(), e.getDbSchema()))
);
} else {
agentExamples.stream().forEach(e ->
exampleStr.append(String.format("<Question:{%s}> ",
e)));
}
variables.put("examples", exampleStr);
Prompt prompt = PromptTemplate.from(REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(variables);
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
String result = response.content().text();
keyPipelineLog.info("NL2SQLParser modelResp:{}", result);
return response.content().text();
}
@@ -217,7 +259,7 @@ public class NL2SQLParser implements ChatParser {
private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) {
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId)
.stream().filter(p -> p.getState() != ParseResp.ParseState.FAILED).collect(Collectors.toList());
.stream().filter(p -> p.getState() == ParseResp.ParseState.COMPLETED).collect(Collectors.toList());
List<ParseResp> contextualList = contextualParseInfoList.subList(0,
Math.min(multiNum, contextualParseInfoList.size()));
@@ -234,16 +276,4 @@ public class NL2SQLParser implements ChatParser {
queryNLReq.getDynamicExemplars().addAll(exemplars);
}
@Builder
@Data
public static class RewriteContext {
private String curtQuestion;
private String histQuestion;
private String curtSchema;
private String histSchema;
private String histSQL;
private ModelConfig modelConfig;
}
}

View File

@@ -23,10 +23,6 @@ public class SchemaMapInfo {
return dataSetElementMatches;
}
public void setDataSetElementMatches(Map<Long, List<SchemaElementMatch>> dataSetElementMatches) {
this.dataSetElementMatches = dataSetElementMatches;
}
public void setMatchedElements(Long dataSet, List<SchemaElementMatch> elementMatches) {
dataSetElementMatches.put(dataSet, elementMatches);
}

View File

@@ -3,7 +3,6 @@ package com.tencent.supersonic.headless.api.pojo.response;
import com.google.common.collect.Lists;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import lombok.Data;
import org.apache.commons.collections.CollectionUtils;
import java.util.Comparator;
import java.util.List;
@@ -14,7 +13,8 @@ public class ParseResp {
private Integer chatId;
private String queryText;
private Long queryId;
private ParseState state;
private ParseState state = ParseState.PENDING;
private String errorMsg;
private List<SemanticParseInfo> selectedParses = Lists.newArrayList();
private ParseTimeCostResp parseTimeCost = new ParseTimeCostResp();
@@ -38,15 +38,6 @@ public class ParseResp {
return selectedParses;
}
public ParseState getState() {
if (CollectionUtils.isNotEmpty(selectedParses)) {
this.state = ParseResp.ParseState.COMPLETED;
} else {
this.state = ParseState.FAILED;
}
return this.state;
}
private void generateParseInfoId(List<SemanticParseInfo> selectedParses) {
for (int i = 0; i < selectedParses.size(); i++) {
SemanticParseInfo parseInfo = selectedParses.get(i);

View File

@@ -45,11 +45,23 @@ public class ChatWorkflowEngine {
switch (queryCtx.getChatWorkflowState()) {
case MAPPING:
performMapping(queryCtx);
queryCtx.setChatWorkflowState(ChatWorkflowState.PARSING);
if (queryCtx.getMapInfo().getMatchedDataSetInfos().size() == 0) {
parseResult.setState(ParseResp.ParseState.FAILED);
parseResult.setErrorMsg("No semantic entities can be mapped against user question.");
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
} else {
queryCtx.setChatWorkflowState(ChatWorkflowState.PARSING);
}
break;
case PARSING:
performParsing(queryCtx, chatCtx);
queryCtx.setChatWorkflowState(ChatWorkflowState.CORRECTING);
if (queryCtx.getCandidateQueries().size() == 0) {
parseResult.setState(ParseResp.ParseState.FAILED);
parseResult.setErrorMsg("No semantic queries can be parsed out.");
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
} else {
queryCtx.setChatWorkflowState(ChatWorkflowState.CORRECTING);
}
break;
case CORRECTING:
performCorrecting(queryCtx);
@@ -64,27 +76,30 @@ public class ChatWorkflowEngine {
case PROCESSING:
default:
performProcessing(queryCtx, chatCtx, parseResult);
if (parseResult.getState().equals(ParseResp.ParseState.PENDING)) {
parseResult.setState(ParseResp.ParseState.COMPLETED);
}
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
break;
}
}
}
public void performMapping(ChatQueryContext queryCtx) {
private void performMapping(ChatQueryContext queryCtx) {
if (Objects.isNull(queryCtx.getMapInfo())
|| MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches())) {
schemaMappers.forEach(mapper -> mapper.map(queryCtx));
}
}
public void performParsing(ChatQueryContext queryCtx, ChatContext chatCtx) {
private void performParsing(ChatQueryContext queryCtx, ChatContext chatCtx) {
semanticParsers.forEach(parser -> {
parser.parse(queryCtx, chatCtx);
log.debug("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
});
}
public void performCorrecting(ChatQueryContext queryCtx) {
private void performCorrecting(ChatQueryContext queryCtx) {
List<SemanticQuery> candidateQueries = queryCtx.getCandidateQueries();
if (CollectionUtils.isNotEmpty(candidateQueries)) {
for (SemanticQuery semanticQuery : candidateQueries) {
@@ -101,7 +116,7 @@ public class ChatWorkflowEngine {
}
}
public void performProcessing(ChatQueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) {
private void performProcessing(ChatQueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) {
resultProcessors.forEach(processor -> {
processor.process(parseResult, queryCtx, chatCtx);
});

View File

@@ -25,8 +25,6 @@ import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
public class MetricTest extends BaseTest {
private int chatId = 10;
@Test
public void testMetricFilter() throws Exception {
QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.metricAgentId);