mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
(feature)(chat)Rewrite error message to make it more understandable to the user. #1320
This commit is contained in:
@@ -83,7 +83,7 @@ public class PlainTextExecutor implements ChatExecutor {
|
||||
private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) {
|
||||
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
|
||||
List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId)
|
||||
.stream().filter(p -> p.getState() != ParseResp.ParseState.FAILED).collect(Collectors.toList());
|
||||
.stream().filter(p -> p.getState() == ParseResp.ParseState.COMPLETED).collect(Collectors.toList());
|
||||
|
||||
List<ParseResp> contextualList = contextualParseInfoList.subList(0,
|
||||
Math.min(multiNum, contextualParseInfoList.size()));
|
||||
|
||||
@@ -6,7 +6,6 @@ import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.config.ModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
@@ -25,8 +24,6 @@ import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -34,7 +31,9 @@ import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
@@ -47,36 +46,56 @@ public class NL2SQLParser implements ChatParser {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
private static final String REWRITE_INSTRUCTION = ""
|
||||
+ "#Role: You are a data product manager experienced in data requirements.\n"
|
||||
private static final String REWRITE_USER_QUESTION_INSTRUCTION = ""
|
||||
+ "#Role: You are a data product manager experienced in data requirements."
|
||||
+ "#Task: Your will be provided with current and history questions asked by a user,"
|
||||
+ "along with their mapped schema elements(metric, dimension and value),"
|
||||
+ "please try understanding the semantics and rewrite a question.\n"
|
||||
+ "please try understanding the semantics and rewrite a question."
|
||||
+ "#Rules: "
|
||||
+ "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges. "
|
||||
+ "2.ONLY respond with the rewritten question.\n"
|
||||
+ "#Current Question: %s\n"
|
||||
+ "#Current Mapped Schema: %s\n"
|
||||
+ "#History Question: %s\n"
|
||||
+ "#History Mapped Schema: %s\n"
|
||||
+ "#History SQL: %s\n"
|
||||
+ "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges."
|
||||
+ "2.ONLY respond with the rewritten question."
|
||||
+ "#Current Question: {{current_question}}"
|
||||
+ "#Current Mapped Schema: {{current_schema}}"
|
||||
+ "#History Question: {{history_question}}"
|
||||
+ "#History Mapped Schema: {{history_schema}}"
|
||||
+ "#History SQL: {{history_sql}}"
|
||||
+ "#Rewritten Question: ";
|
||||
|
||||
private static final String REWRITE_ERROR_MESSAGE_INSTRUCTION = ""
|
||||
+ "#Role: You are a data business partner who closely interacts with business people.\n"
|
||||
+ "#Task: Your will be provided with user input, system output and some examples, "
|
||||
+ "please respond shortly to teach user how to ask the right question, "
|
||||
+ "using `Examples` as references."
|
||||
+ "#Rules: ALWAYS use the same language as the `Input`.\n"
|
||||
+ "#Input: {{user_question}}\n"
|
||||
+ "#Output: {{system_message}}\n"
|
||||
+ "#Examples: {{examples}}\n"
|
||||
+ "#Response: ";
|
||||
|
||||
@Override
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) {
|
||||
return;
|
||||
}
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
|
||||
chatParseContext.getAgent().getModelConfig());
|
||||
|
||||
processMultiTurn(chatParseContext);
|
||||
processMultiTurn(chatLanguageModel, chatParseContext);
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
addDynamicExemplars(chatParseContext.getAgent().getId(), queryNLReq);
|
||||
|
||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryNLReq);
|
||||
if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) {
|
||||
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
|
||||
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
|
||||
} else {
|
||||
parseResp.setErrorMsg(rewriteErrorMessage(chatLanguageModel,
|
||||
chatParseContext.getQueryText(),
|
||||
text2SqlParseResp.getErrorMsg(),
|
||||
queryNLReq.getDynamicExemplars(),
|
||||
chatParseContext.getAgent().getExamples()));
|
||||
}
|
||||
parseResp.setState(text2SqlParseResp.getState());
|
||||
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
|
||||
formatParseResult(parseResp);
|
||||
}
|
||||
@@ -136,7 +155,7 @@ public class NL2SQLParser implements ChatParser {
|
||||
parseInfo.setTextInfo(textBuilder.toString());
|
||||
}
|
||||
|
||||
private void processMultiTurn(ChatParseContext chatParseContext) {
|
||||
private void processMultiTurn(ChatLanguageModel chatLanguageModel, ChatParseContext chatParseContext) {
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig();
|
||||
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
|
||||
@@ -162,30 +181,53 @@ public class NL2SQLParser implements ChatParser {
|
||||
String curtMapStr = generateSchemaPrompt(currentMapResult.getMapInfo().getMatchedElements(dataId));
|
||||
String histMapStr = generateSchemaPrompt(lastParseResult.getSelectedParses().get(0).getElementMatches());
|
||||
String histSQL = lastParseResult.getSelectedParses().get(0).getSqlInfo().getCorrectedS2SQL();
|
||||
String rewrittenQuery = rewriteQuery(RewriteContext.builder()
|
||||
.curtQuestion(currentMapResult.getQueryText())
|
||||
.histQuestion(lastParseResult.getQueryText())
|
||||
.curtSchema(curtMapStr)
|
||||
.histSchema(histMapStr)
|
||||
.histSQL(histSQL)
|
||||
.modelConfig(queryNLReq.getModelConfig())
|
||||
.build());
|
||||
|
||||
Map<String, Object> variables = new HashMap<>();
|
||||
variables.put("current_question", currentMapResult.getQueryText());
|
||||
variables.put("current_schema", curtMapStr);
|
||||
variables.put("history_question", lastParseResult.getQueryText());
|
||||
variables.put("history_schema", histMapStr);
|
||||
variables.put("history_sql", histSQL);
|
||||
|
||||
Prompt prompt = PromptTemplate.from(REWRITE_USER_QUESTION_INSTRUCTION).apply(variables);
|
||||
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text());
|
||||
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
String rewrittenQuery = response.content().text();
|
||||
keyPipelineLog.info("NL2SQLParser modelResp:{}", rewrittenQuery);
|
||||
|
||||
chatParseContext.setQueryText(rewrittenQuery);
|
||||
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
|
||||
lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
|
||||
}
|
||||
|
||||
private String rewriteQuery(RewriteContext context) {
|
||||
String promptStr = String.format(REWRITE_INSTRUCTION, context.getCurtQuestion(), context.getCurtSchema(),
|
||||
context.getHistQuestion(), context.getHistSchema(), context.getHistSQL());
|
||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", promptStr);
|
||||
private String rewriteErrorMessage(ChatLanguageModel chatLanguageModel, String userQuestion,
|
||||
String errMsg, List<SqlExemplar> similarExemplars,
|
||||
List<String> agentExamples) {
|
||||
Map<String, Object> variables = new HashMap<>();
|
||||
variables.put("user_question", userQuestion);
|
||||
variables.put("system_message", errMsg);
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(context.getModelConfig());
|
||||
StringBuilder exampleStr = new StringBuilder();
|
||||
if (similarExemplars.size() > 0) {
|
||||
similarExemplars.stream().forEach(e ->
|
||||
exampleStr.append(String.format("<Question:{%s},Schema:{%s}> ",
|
||||
e.getQuestion(), e.getDbSchema()))
|
||||
);
|
||||
} else {
|
||||
agentExamples.stream().forEach(e ->
|
||||
exampleStr.append(String.format("<Question:{%s}> ",
|
||||
e)));
|
||||
}
|
||||
variables.put("examples", exampleStr);
|
||||
|
||||
Prompt prompt = PromptTemplate.from(REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(variables);
|
||||
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
|
||||
String result = response.content().text();
|
||||
keyPipelineLog.info("NL2SQLParser modelResp:{}", result);
|
||||
|
||||
return response.content().text();
|
||||
}
|
||||
|
||||
@@ -217,7 +259,7 @@ public class NL2SQLParser implements ChatParser {
|
||||
private List<ParseResp> getHistoryParseResult(int chatId, int multiNum) {
|
||||
ChatQueryRepository chatQueryRepository = ContextUtils.getBean(ChatQueryRepository.class);
|
||||
List<ParseResp> contextualParseInfoList = chatQueryRepository.getContextualParseInfo(chatId)
|
||||
.stream().filter(p -> p.getState() != ParseResp.ParseState.FAILED).collect(Collectors.toList());
|
||||
.stream().filter(p -> p.getState() == ParseResp.ParseState.COMPLETED).collect(Collectors.toList());
|
||||
|
||||
List<ParseResp> contextualList = contextualParseInfoList.subList(0,
|
||||
Math.min(multiNum, contextualParseInfoList.size()));
|
||||
@@ -234,16 +276,4 @@ public class NL2SQLParser implements ChatParser {
|
||||
queryNLReq.getDynamicExemplars().addAll(exemplars);
|
||||
}
|
||||
|
||||
@Builder
|
||||
@Data
|
||||
public static class RewriteContext {
|
||||
|
||||
private String curtQuestion;
|
||||
private String histQuestion;
|
||||
private String curtSchema;
|
||||
private String histSchema;
|
||||
private String histSQL;
|
||||
private ModelConfig modelConfig;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -23,10 +23,6 @@ public class SchemaMapInfo {
|
||||
return dataSetElementMatches;
|
||||
}
|
||||
|
||||
public void setDataSetElementMatches(Map<Long, List<SchemaElementMatch>> dataSetElementMatches) {
|
||||
this.dataSetElementMatches = dataSetElementMatches;
|
||||
}
|
||||
|
||||
public void setMatchedElements(Long dataSet, List<SchemaElementMatch> elementMatches) {
|
||||
dataSetElementMatches.put(dataSet, elementMatches);
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package com.tencent.supersonic.headless.api.pojo.response;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import lombok.Data;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
@@ -14,7 +13,8 @@ public class ParseResp {
|
||||
private Integer chatId;
|
||||
private String queryText;
|
||||
private Long queryId;
|
||||
private ParseState state;
|
||||
private ParseState state = ParseState.PENDING;
|
||||
private String errorMsg;
|
||||
private List<SemanticParseInfo> selectedParses = Lists.newArrayList();
|
||||
private ParseTimeCostResp parseTimeCost = new ParseTimeCostResp();
|
||||
|
||||
@@ -38,15 +38,6 @@ public class ParseResp {
|
||||
return selectedParses;
|
||||
}
|
||||
|
||||
public ParseState getState() {
|
||||
if (CollectionUtils.isNotEmpty(selectedParses)) {
|
||||
this.state = ParseResp.ParseState.COMPLETED;
|
||||
} else {
|
||||
this.state = ParseState.FAILED;
|
||||
}
|
||||
return this.state;
|
||||
}
|
||||
|
||||
private void generateParseInfoId(List<SemanticParseInfo> selectedParses) {
|
||||
for (int i = 0; i < selectedParses.size(); i++) {
|
||||
SemanticParseInfo parseInfo = selectedParses.get(i);
|
||||
|
||||
@@ -45,11 +45,23 @@ public class ChatWorkflowEngine {
|
||||
switch (queryCtx.getChatWorkflowState()) {
|
||||
case MAPPING:
|
||||
performMapping(queryCtx);
|
||||
queryCtx.setChatWorkflowState(ChatWorkflowState.PARSING);
|
||||
if (queryCtx.getMapInfo().getMatchedDataSetInfos().size() == 0) {
|
||||
parseResult.setState(ParseResp.ParseState.FAILED);
|
||||
parseResult.setErrorMsg("No semantic entities can be mapped against user question.");
|
||||
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
|
||||
} else {
|
||||
queryCtx.setChatWorkflowState(ChatWorkflowState.PARSING);
|
||||
}
|
||||
break;
|
||||
case PARSING:
|
||||
performParsing(queryCtx, chatCtx);
|
||||
queryCtx.setChatWorkflowState(ChatWorkflowState.CORRECTING);
|
||||
if (queryCtx.getCandidateQueries().size() == 0) {
|
||||
parseResult.setState(ParseResp.ParseState.FAILED);
|
||||
parseResult.setErrorMsg("No semantic queries can be parsed out.");
|
||||
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
|
||||
} else {
|
||||
queryCtx.setChatWorkflowState(ChatWorkflowState.CORRECTING);
|
||||
}
|
||||
break;
|
||||
case CORRECTING:
|
||||
performCorrecting(queryCtx);
|
||||
@@ -64,27 +76,30 @@ public class ChatWorkflowEngine {
|
||||
case PROCESSING:
|
||||
default:
|
||||
performProcessing(queryCtx, chatCtx, parseResult);
|
||||
if (parseResult.getState().equals(ParseResp.ParseState.PENDING)) {
|
||||
parseResult.setState(ParseResp.ParseState.COMPLETED);
|
||||
}
|
||||
queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void performMapping(ChatQueryContext queryCtx) {
|
||||
private void performMapping(ChatQueryContext queryCtx) {
|
||||
if (Objects.isNull(queryCtx.getMapInfo())
|
||||
|| MapUtils.isEmpty(queryCtx.getMapInfo().getDataSetElementMatches())) {
|
||||
schemaMappers.forEach(mapper -> mapper.map(queryCtx));
|
||||
}
|
||||
}
|
||||
|
||||
public void performParsing(ChatQueryContext queryCtx, ChatContext chatCtx) {
|
||||
private void performParsing(ChatQueryContext queryCtx, ChatContext chatCtx) {
|
||||
semanticParsers.forEach(parser -> {
|
||||
parser.parse(queryCtx, chatCtx);
|
||||
log.debug("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
|
||||
});
|
||||
}
|
||||
|
||||
public void performCorrecting(ChatQueryContext queryCtx) {
|
||||
private void performCorrecting(ChatQueryContext queryCtx) {
|
||||
List<SemanticQuery> candidateQueries = queryCtx.getCandidateQueries();
|
||||
if (CollectionUtils.isNotEmpty(candidateQueries)) {
|
||||
for (SemanticQuery semanticQuery : candidateQueries) {
|
||||
@@ -101,7 +116,7 @@ public class ChatWorkflowEngine {
|
||||
}
|
||||
}
|
||||
|
||||
public void performProcessing(ChatQueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) {
|
||||
private void performProcessing(ChatQueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) {
|
||||
resultProcessors.forEach(processor -> {
|
||||
processor.process(parseResult, queryCtx, chatCtx);
|
||||
});
|
||||
|
||||
@@ -25,8 +25,6 @@ import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
|
||||
|
||||
public class MetricTest extends BaseTest {
|
||||
|
||||
private int chatId = 10;
|
||||
|
||||
@Test
|
||||
public void testMetricFilter() throws Exception {
|
||||
QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.metricAgentId);
|
||||
|
||||
Reference in New Issue
Block a user