[improvement][chat]Move generation of semantic text info and rewrite of error message to dedicated ResultProcessor.

This commit is contained in:
jerryjzhang
2024-10-27 22:34:16 +08:00
parent bb363a0286
commit b69ee81d58
13 changed files with 165 additions and 124 deletions

View File

@@ -42,7 +42,11 @@ public class Agent extends RecordInfo {
}
public boolean enableSearch() {
return enableSearch != null && enableSearch == 1;
return enableSearch == 1;
}
public boolean enableFeedback() {
return enableFeedback == 1;
}
public boolean enableMemoryReview() {

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.chat.server.service.ChatContextService;
@@ -15,11 +14,9 @@ import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
@@ -35,7 +32,6 @@ import dev.langchain4j.provider.ModelProvider;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
import java.util.ArrayList;
import java.util.Collections;
@@ -43,8 +39,6 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
@@ -68,26 +62,11 @@ public class NL2SQLParser implements ChatQueryParser {
+ "#History Mapped Schema: {{history_schema}}" + "#History SQL: {{history_sql}}"
+ "#Rewritten Question: ";
public static final String APP_KEY_ERROR_MESSAGE = "REWRITE_ERROR_MESSAGE";
private static final String REWRITE_ERROR_MESSAGE_INSTRUCTION = ""
+ "#Role: You are a data business partner who closely interacts with business people.\n"
+ "#Task: Your will be provided with user input, system output and some examples, "
+ "please respond shortly to teach user how to ask the right question, "
+ "by using `Examples` as references."
+ "#Rules: ALWAYS respond with the same language as the `Input`.\n"
+ "#Input: {{user_question}}\n" + "#Output: {{system_message}}\n"
+ "#Examples: {{examples}}\n" + "#Response: ";
public NL2SQLParser() {
ChatAppManager.register(APP_KEY_MULTI_TURN,
ChatApp.builder().prompt(REWRITE_MULTI_TURN_INSTRUCTION).name("多轮对话改写")
.appModule(AppModule.CHAT).description("通过大模型根据历史对话来改写本轮对话").enable(false)
.build());
ChatAppManager.register(APP_KEY_ERROR_MESSAGE,
ChatApp.builder().prompt(REWRITE_ERROR_MESSAGE_INSTRUCTION).name("异常提示改写")
.appModule(AppModule.CHAT).description("通过大模型将异常信息改写为更友好和引导性的提示用语")
.enable(false).build());
}
@Override
@@ -102,8 +81,10 @@ public class NL2SQLParser implements ChatQueryParser {
ParseResp parseResp = parseContext.getResponse();
ChatParseReq parseReq = parseContext.getRequest();
if (!parseContext.getRequest().isDisableLLM()) {
if (!parseContext.getRequest().isDisableLLM() && queryNLReq.getText2SQLType().enableLLM()) {
processMultiTurn(parseContext);
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
parseResp.setUsedExemplars(queryNLReq.getDynamicExemplars());
}
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
@@ -111,64 +92,16 @@ public class NL2SQLParser implements ChatQueryParser {
if (chatCtx != null) {
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
}
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
ParseResp text2SqlParseResp = chatLayerService.parse(queryNLReq);
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
} else {
if (!parseReq.isDisableLLM()) {
parseResp.setErrorMsg(rewriteErrorMessage(parseContext,
text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars()));
}
}
parseResp.setErrorMsg(text2SqlParseResp.getErrorMsg());
parseResp.setState(text2SqlParseResp.getState());
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
parseResp.setErrorMsg(text2SqlParseResp.getErrorMsg());
formatParseResult(parseResp);
}
private void formatParseResult(ParseResp parseResp) {
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
for (SemanticParseInfo parseInfo : selectedParses) {
formatParseInfo(parseInfo);
}
}
private void formatParseInfo(SemanticParseInfo parseInfo) {
if (!PluginQueryManager.isPluginQuery(parseInfo.getQueryMode())) {
formatNL2SQLParseInfo(parseInfo);
}
}
private void formatNL2SQLParseInfo(SemanticParseInfo parseInfo) {
StringBuilder textBuilder = new StringBuilder();
textBuilder.append("**数据集:** ").append(parseInfo.getDataSet().getName()).append(" ");
Optional<SchemaElement> metric = parseInfo.getMetrics().stream().findFirst();
metric.ifPresent(schemaElement -> textBuilder.append("**指标:** ")
.append(schemaElement.getName()).append(" "));
List<String> dimensionNames = parseInfo.getDimensions().stream().map(SchemaElement::getName)
.filter(Objects::nonNull).collect(Collectors.toList());
if (!CollectionUtils.isEmpty(dimensionNames)) {
textBuilder.append("**维度:** ").append(String.join(",", dimensionNames));
}
textBuilder.append("\n\n**筛选条件:** \n");
if (parseInfo.getDateInfo() != null) {
textBuilder.append("**数据时间:** ").append(parseInfo.getDateInfo().getStartDate())
.append("~").append(parseInfo.getDateInfo().getEndDate()).append(" ");
}
if (!CollectionUtils.isEmpty(parseInfo.getDimensionFilters())
|| CollectionUtils.isEmpty(parseInfo.getMetricFilters())) {
Set<QueryFilter> queryFilters = parseInfo.getDimensionFilters();
queryFilters.addAll(parseInfo.getMetricFilters());
for (QueryFilter queryFilter : queryFilters) {
textBuilder.append("**").append(queryFilter.getName()).append("**").append(" ")
.append(queryFilter.getOperator().getValue()).append(" ")
.append(queryFilter.getValue()).append(" ");
}
}
parseInfo.setTextInfo(textBuilder.toString());
}
private void processMultiTurn(ParseContext parseContext) {
@@ -214,35 +147,6 @@ public class NL2SQLParser implements ChatQueryParser {
currentMapResult.getQueryText(), rewrittenQuery);
}
private String rewriteErrorMessage(ParseContext parseContext, String errMsg,
List<Text2SQLExemplar> similarExemplars) {
ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_ERROR_MESSAGE);
if (Objects.isNull(chatApp) || !chatApp.isEnable()) {
return errMsg;
}
Map<String, Object> variables = new HashMap<>();
variables.put("user_question", parseContext.getRequest().getQueryText());
variables.put("system_message", errMsg);
StringBuilder exampleStr = new StringBuilder();
similarExemplars.forEach(e -> exampleStr.append(
String.format("<Question:{%s},Schema:{%s}> ", e.getQuestion(), e.getDbSchema())));
parseContext.getAgent().getExamples()
.forEach(e -> exampleStr.append(String.format("<Question:{%s}> ", e)));
variables.put("examples", exampleStr);
Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variables);
ChatLanguageModel chatLanguageModel =
ModelProvider.getChatModel(ModelConfigHelper.getChatModelConfig(chatApp));
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
String rewrittenMsg = response.content().text();
keyPipelineLog.info("ErrorRewrite modelReq:\n{} \nmodelResp:\n{}", prompt.text(), response);
return rewrittenMsg;
}
private String generateSchemaPrompt(List<SchemaElementMatch> elementMatches) {
List<String> metrics = new ArrayList<>();
List<String> dimensions = new ArrayList<>();

View File

@@ -0,0 +1,72 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.headless.server.utils.ModelConfigHelper;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.provider.ModelProvider;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
/**
 * Parse-result processor that rewrites the parser's error message through an LLM so that the
 * user receives guidance on how to rephrase the question.
 *
 * <p>The rewrite is driven by the chat app registered under {@link #APP_KEY_ERROR_MESSAGE}. The
 * app is registered as disabled, so the processor leaves the response untouched unless the
 * agent's chat-app configuration contains an enabled entry for that key.
 */
public class ErrorMessageProcessor implements ParseResultProcessor {

    private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");

    /** Key under which the error-message rewriting chat app is registered and looked up. */
    public static final String APP_KEY_ERROR_MESSAGE = "REWRITE_ERROR_MESSAGE";

    // NOTE(review): there is no "\n" between "...as references." and "#Rules:", unlike every
    // other section of this prompt, and "Your will be provided" looks like a typo for
    // "You will be provided". Left unchanged here because the text is sent to the model
    // verbatim; confirm with the prompt owner before editing.
    private static final String REWRITE_ERROR_MESSAGE_INSTRUCTION = ""
            + "#Role: You are a data business partner who closely interacts with business people.\n"
            + "#Task: Your will be provided with user input, system output and some examples, "
            + "please respond shortly to teach user how to ask the right question, "
            + "by using `Examples` as references."
            + "#Rules: ALWAYS respond with the same language as the `Input`.\n"
            + "#Input: {{user_question}}\n" + "#Output: {{system_message}}\n"
            + "#Examples: {{examples}}\n" + "#Response: ";

    /** Registers the default (disabled) error-message rewriting chat app. */
    public ErrorMessageProcessor() {
        ChatAppManager.register(APP_KEY_ERROR_MESSAGE,
                ChatApp.builder().prompt(REWRITE_ERROR_MESSAGE_INSTRUCTION).name("异常提示改写")
                        .appModule(AppModule.CHAT).description("通过大模型将异常信息改写为更友好和引导性的提示用语")
                        .enable(false).build());
    }

    /**
     * Replaces the response's error message with an LLM-generated rewrite.
     *
     * <p>Nothing happens when the response carries no error message, when the agent has no
     * chat app configured for {@link #APP_KEY_ERROR_MESSAGE}, or when that app is disabled. If
     * the model answers with a blank text, the original error message is kept.
     *
     * @param parseContext context holding the request, the agent and the response to update
     */
    @Override
    public void process(ParseContext parseContext) {
        String errMsg = parseContext.getResponse().getErrorMsg();
        // Check the cheap condition first: without an error there is nothing to rewrite and
        // no reason to touch the agent configuration at all.
        if (StringUtils.isBlank(errMsg)) {
            return;
        }
        ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_ERROR_MESSAGE);
        if (Objects.isNull(chatApp) || !chatApp.isEnable()) {
            return;
        }

        Map<String, Object> variables = new HashMap<>();
        variables.put("user_question", parseContext.getRequest().getQueryText());
        variables.put("system_message", errMsg);
        variables.put("examples", buildExamples(parseContext));

        Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variables);
        ChatLanguageModel chatLanguageModel =
                ModelProvider.getChatModel(ModelConfigHelper.getChatModelConfig(chatApp));
        Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
        String rewrittenMsg = response.content().text();
        // A blank reply would erase a meaningful error, so only overwrite with real content.
        if (StringUtils.isNotBlank(rewrittenMsg)) {
            parseContext.getResponse().setErrorMsg(rewrittenMsg);
        }
        keyPipelineLog.info("ErrorMessageProcessor modelReq:\n{} \nmodelResp:\n{}", prompt.text(),
                rewrittenMsg);
    }

    /**
     * Renders the exemplars recorded on the response followed by the agent's example questions
     * into the text substituted for the prompt's {@code examples} variable.
     *
     * <p>Either collection may be absent (for instance, used exemplars are only recorded when
     * the parser took its LLM path), in which case it simply contributes nothing.
     */
    private static String buildExamples(ParseContext parseContext) {
        StringBuilder exampleStr = new StringBuilder();
        if (parseContext.getResponse().getUsedExemplars() != null) {
            parseContext.getResponse().getUsedExemplars()
                    .forEach(e -> exampleStr.append(String.format("<Question:{%s},Schema:{%s}> ",
                            e.getQuestion(), e.getDbSchema())));
        }
        if (parseContext.getAgent().getExamples() != null) {
            parseContext.getAgent().getExamples()
                    .forEach(e -> exampleStr.append(String.format("<Question:{%s}> ", e)));
        }
        return exampleStr.toString();
    }
}

View File

@@ -0,0 +1,54 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * Parse-result processor that generates the markdown summary ("text info") of every selected
 * parse that is not a plugin query. The summary lists the data set, the first metric, the
 * dimensions, the date range and the dimension/metric filters of the parse.
 */
public class TextInfoProcessor implements ParseResultProcessor {

    /**
     * Sets the text info on each selected parse of the response, skipping plugin queries.
     *
     * @param parseContext context whose response holds the selected parses to describe
     */
    @Override
    public void process(ParseContext parseContext) {
        parseContext.getResponse().getSelectedParses().forEach(p -> {
            if (!PluginQueryManager.isPluginQuery(p.getQueryMode())) {
                formatNL2SQLParseInfo(p);
            }
        });
    }

    /**
     * Builds the markdown summary for one parse and stores it via
     * {@code parseInfo.setTextInfo(...)}.
     */
    private static void formatNL2SQLParseInfo(SemanticParseInfo parseInfo) {
        StringBuilder textBuilder = new StringBuilder();
        textBuilder.append("**数据集:** ").append(parseInfo.getDataSet().getName()).append(" ");

        // Only the first metric is shown in the summary.
        Optional<SchemaElement> metric = parseInfo.getMetrics().stream().findFirst();
        metric.ifPresent(schemaElement -> textBuilder.append("**指标:** ")
                .append(schemaElement.getName()).append(" "));

        List<String> dimensionNames = parseInfo.getDimensions().stream().map(SchemaElement::getName)
                .filter(Objects::nonNull).collect(Collectors.toList());
        if (!CollectionUtils.isEmpty(dimensionNames)) {
            textBuilder.append("**维度:** ").append(String.join(",", dimensionNames));
        }

        textBuilder.append("\n\n**筛选条件:** \n");
        if (parseInfo.getDateInfo() != null) {
            textBuilder.append("**数据时间:** ").append(parseInfo.getDateInfo().getStartDate())
                    .append("~").append(parseInfo.getDateInfo().getEndDate()).append(" ");
        }

        // Render both filter groups without merging them into the parse's own dimension-filter
        // set: the previous addAll() leaked metric filters into the dimension filters, and its
        // guard (`|| isEmpty(metricFilters)`) both skipped metric-only filters and entered the
        // branch when neither group was present.
        appendFilters(textBuilder, parseInfo.getDimensionFilters());
        appendFilters(textBuilder, parseInfo.getMetricFilters());

        parseInfo.setTextInfo(textBuilder.toString());
    }

    /**
     * Appends one {@code **name** operator value} entry per filter; a {@code null} group is
     * treated as empty.
     */
    private static void appendFilters(StringBuilder textBuilder,
            Iterable<? extends QueryFilter> filters) {
        if (filters == null) {
            return;
        }
        for (QueryFilter queryFilter : filters) {
            textBuilder.append("**").append(queryFilter.getName()).append("**").append(" ")
                    .append(queryFilter.getOperator().getValue()).append(" ")
                    .append(queryFilter.getValue()).append(" ");
        }
    }
}