mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
[improvement][chat]Move generation of semantic text info and rewrite of error message to dedicated ResultProcessor.
This commit is contained in:
@@ -42,7 +42,11 @@ public class Agent extends RecordInfo {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public boolean enableSearch() {
|
public boolean enableSearch() {
|
||||||
return enableSearch != null && enableSearch == 1;
|
return enableSearch == 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean enableFeedback() {
|
||||||
|
return enableFeedback == 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean enableMemoryReview() {
|
public boolean enableMemoryReview() {
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.server.parser;
|
|||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||||
@@ -15,11 +14,9 @@ import com.tencent.supersonic.common.pojo.enums.AppModule;
|
|||||||
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
||||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
@@ -35,7 +32,6 @@ import dev.langchain4j.provider.ModelProvider;
|
|||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.springframework.util.CollectionUtils;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
@@ -43,8 +39,6 @@ import java.util.HashMap;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
|
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
|
||||||
@@ -68,26 +62,11 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
+ "#History Mapped Schema: {{history_schema}}" + "#History SQL: {{history_sql}}"
|
+ "#History Mapped Schema: {{history_schema}}" + "#History SQL: {{history_sql}}"
|
||||||
+ "#Rewritten Question: ";
|
+ "#Rewritten Question: ";
|
||||||
|
|
||||||
public static final String APP_KEY_ERROR_MESSAGE = "REWRITE_ERROR_MESSAGE";
|
|
||||||
private static final String REWRITE_ERROR_MESSAGE_INSTRUCTION = ""
|
|
||||||
+ "#Role: You are a data business partner who closely interacts with business people.\n"
|
|
||||||
+ "#Task: Your will be provided with user input, system output and some examples, "
|
|
||||||
+ "please respond shortly to teach user how to ask the right question, "
|
|
||||||
+ "by using `Examples` as references."
|
|
||||||
+ "#Rules: ALWAYS respond with the same language as the `Input`.\n"
|
|
||||||
+ "#Input: {{user_question}}\n" + "#Output: {{system_message}}\n"
|
|
||||||
+ "#Examples: {{examples}}\n" + "#Response: ";
|
|
||||||
|
|
||||||
public NL2SQLParser() {
|
public NL2SQLParser() {
|
||||||
ChatAppManager.register(APP_KEY_MULTI_TURN,
|
ChatAppManager.register(APP_KEY_MULTI_TURN,
|
||||||
ChatApp.builder().prompt(REWRITE_MULTI_TURN_INSTRUCTION).name("多轮对话改写")
|
ChatApp.builder().prompt(REWRITE_MULTI_TURN_INSTRUCTION).name("多轮对话改写")
|
||||||
.appModule(AppModule.CHAT).description("通过大模型根据历史对话来改写本轮对话").enable(false)
|
.appModule(AppModule.CHAT).description("通过大模型根据历史对话来改写本轮对话").enable(false)
|
||||||
.build());
|
.build());
|
||||||
|
|
||||||
ChatAppManager.register(APP_KEY_ERROR_MESSAGE,
|
|
||||||
ChatApp.builder().prompt(REWRITE_ERROR_MESSAGE_INSTRUCTION).name("异常提示改写")
|
|
||||||
.appModule(AppModule.CHAT).description("通过大模型将异常信息改写为更友好和引导性的提示用语")
|
|
||||||
.enable(false).build());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -102,8 +81,10 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
ParseResp parseResp = parseContext.getResponse();
|
ParseResp parseResp = parseContext.getResponse();
|
||||||
ChatParseReq parseReq = parseContext.getRequest();
|
ChatParseReq parseReq = parseContext.getRequest();
|
||||||
|
|
||||||
if (!parseContext.getRequest().isDisableLLM()) {
|
if (!parseContext.getRequest().isDisableLLM() && queryNLReq.getText2SQLType().enableLLM()) {
|
||||||
processMultiTurn(parseContext);
|
processMultiTurn(parseContext);
|
||||||
|
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
|
||||||
|
parseResp.setUsedExemplars(queryNLReq.getDynamicExemplars());
|
||||||
}
|
}
|
||||||
|
|
||||||
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
||||||
@@ -111,64 +92,16 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
if (chatCtx != null) {
|
if (chatCtx != null) {
|
||||||
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
|
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
|
||||||
}
|
}
|
||||||
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
|
|
||||||
|
|
||||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||||
ParseResp text2SqlParseResp = chatLayerService.parse(queryNLReq);
|
ParseResp text2SqlParseResp = chatLayerService.parse(queryNLReq);
|
||||||
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
|
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
|
||||||
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
|
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
|
||||||
} else {
|
|
||||||
if (!parseReq.isDisableLLM()) {
|
|
||||||
parseResp.setErrorMsg(rewriteErrorMessage(parseContext,
|
|
||||||
text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars()));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
parseResp.setErrorMsg(text2SqlParseResp.getErrorMsg());
|
||||||
parseResp.setState(text2SqlParseResp.getState());
|
parseResp.setState(text2SqlParseResp.getState());
|
||||||
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
|
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
|
||||||
parseResp.setErrorMsg(text2SqlParseResp.getErrorMsg());
|
parseResp.setErrorMsg(text2SqlParseResp.getErrorMsg());
|
||||||
formatParseResult(parseResp);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void formatParseResult(ParseResp parseResp) {
|
|
||||||
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
|
|
||||||
for (SemanticParseInfo parseInfo : selectedParses) {
|
|
||||||
formatParseInfo(parseInfo);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void formatParseInfo(SemanticParseInfo parseInfo) {
|
|
||||||
if (!PluginQueryManager.isPluginQuery(parseInfo.getQueryMode())) {
|
|
||||||
formatNL2SQLParseInfo(parseInfo);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void formatNL2SQLParseInfo(SemanticParseInfo parseInfo) {
|
|
||||||
StringBuilder textBuilder = new StringBuilder();
|
|
||||||
textBuilder.append("**数据集:** ").append(parseInfo.getDataSet().getName()).append(" ");
|
|
||||||
Optional<SchemaElement> metric = parseInfo.getMetrics().stream().findFirst();
|
|
||||||
metric.ifPresent(schemaElement -> textBuilder.append("**指标:** ")
|
|
||||||
.append(schemaElement.getName()).append(" "));
|
|
||||||
List<String> dimensionNames = parseInfo.getDimensions().stream().map(SchemaElement::getName)
|
|
||||||
.filter(Objects::nonNull).collect(Collectors.toList());
|
|
||||||
if (!CollectionUtils.isEmpty(dimensionNames)) {
|
|
||||||
textBuilder.append("**维度:** ").append(String.join(",", dimensionNames));
|
|
||||||
}
|
|
||||||
textBuilder.append("\n\n**筛选条件:** \n");
|
|
||||||
if (parseInfo.getDateInfo() != null) {
|
|
||||||
textBuilder.append("**数据时间:** ").append(parseInfo.getDateInfo().getStartDate())
|
|
||||||
.append("~").append(parseInfo.getDateInfo().getEndDate()).append(" ");
|
|
||||||
}
|
|
||||||
if (!CollectionUtils.isEmpty(parseInfo.getDimensionFilters())
|
|
||||||
|| CollectionUtils.isEmpty(parseInfo.getMetricFilters())) {
|
|
||||||
Set<QueryFilter> queryFilters = parseInfo.getDimensionFilters();
|
|
||||||
queryFilters.addAll(parseInfo.getMetricFilters());
|
|
||||||
for (QueryFilter queryFilter : queryFilters) {
|
|
||||||
textBuilder.append("**").append(queryFilter.getName()).append("**").append(" ")
|
|
||||||
.append(queryFilter.getOperator().getValue()).append(" ")
|
|
||||||
.append(queryFilter.getValue()).append(" ");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
parseInfo.setTextInfo(textBuilder.toString());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void processMultiTurn(ParseContext parseContext) {
|
private void processMultiTurn(ParseContext parseContext) {
|
||||||
@@ -214,35 +147,6 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
currentMapResult.getQueryText(), rewrittenQuery);
|
currentMapResult.getQueryText(), rewrittenQuery);
|
||||||
}
|
}
|
||||||
|
|
||||||
private String rewriteErrorMessage(ParseContext parseContext, String errMsg,
|
|
||||||
List<Text2SQLExemplar> similarExemplars) {
|
|
||||||
|
|
||||||
ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_ERROR_MESSAGE);
|
|
||||||
if (Objects.isNull(chatApp) || !chatApp.isEnable()) {
|
|
||||||
return errMsg;
|
|
||||||
}
|
|
||||||
|
|
||||||
Map<String, Object> variables = new HashMap<>();
|
|
||||||
variables.put("user_question", parseContext.getRequest().getQueryText());
|
|
||||||
variables.put("system_message", errMsg);
|
|
||||||
|
|
||||||
StringBuilder exampleStr = new StringBuilder();
|
|
||||||
similarExemplars.forEach(e -> exampleStr.append(
|
|
||||||
String.format("<Question:{%s},Schema:{%s}> ", e.getQuestion(), e.getDbSchema())));
|
|
||||||
parseContext.getAgent().getExamples()
|
|
||||||
.forEach(e -> exampleStr.append(String.format("<Question:{%s}> ", e)));
|
|
||||||
variables.put("examples", exampleStr);
|
|
||||||
|
|
||||||
Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variables);
|
|
||||||
ChatLanguageModel chatLanguageModel =
|
|
||||||
ModelProvider.getChatModel(ModelConfigHelper.getChatModelConfig(chatApp));
|
|
||||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
|
||||||
String rewrittenMsg = response.content().text();
|
|
||||||
keyPipelineLog.info("ErrorRewrite modelReq:\n{} \nmodelResp:\n{}", prompt.text(), response);
|
|
||||||
|
|
||||||
return rewrittenMsg;
|
|
||||||
}
|
|
||||||
|
|
||||||
private String generateSchemaPrompt(List<SchemaElementMatch> elementMatches) {
|
private String generateSchemaPrompt(List<SchemaElementMatch> elementMatches) {
|
||||||
List<String> metrics = new ArrayList<>();
|
List<String> metrics = new ArrayList<>();
|
||||||
List<String> dimensions = new ArrayList<>();
|
List<String> dimensions = new ArrayList<>();
|
||||||
|
|||||||
@@ -0,0 +1,72 @@
|
|||||||
|
package com.tencent.supersonic.chat.server.processor.parse;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
|
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.AppModule;
|
||||||
|
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||||
|
import com.tencent.supersonic.headless.server.utils.ModelConfigHelper;
|
||||||
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
|
import dev.langchain4j.model.input.Prompt;
|
||||||
|
import dev.langchain4j.model.input.PromptTemplate;
|
||||||
|
import dev.langchain4j.model.output.Response;
|
||||||
|
import dev.langchain4j.provider.ModelProvider;
|
||||||
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
|
public class ErrorMessageProcessor implements ParseResultProcessor {
|
||||||
|
|
||||||
|
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||||
|
|
||||||
|
public static final String APP_KEY_ERROR_MESSAGE = "REWRITE_ERROR_MESSAGE";
|
||||||
|
private static final String REWRITE_ERROR_MESSAGE_INSTRUCTION = ""
|
||||||
|
+ "#Role: You are a data business partner who closely interacts with business people.\n"
|
||||||
|
+ "#Task: Your will be provided with user input, system output and some examples, "
|
||||||
|
+ "please respond shortly to teach user how to ask the right question, "
|
||||||
|
+ "by using `Examples` as references."
|
||||||
|
+ "#Rules: ALWAYS respond with the same language as the `Input`.\n"
|
||||||
|
+ "#Input: {{user_question}}\n" + "#Output: {{system_message}}\n"
|
||||||
|
+ "#Examples: {{examples}}\n" + "#Response: ";
|
||||||
|
|
||||||
|
public ErrorMessageProcessor() {
|
||||||
|
ChatAppManager.register(APP_KEY_ERROR_MESSAGE,
|
||||||
|
ChatApp.builder().prompt(REWRITE_ERROR_MESSAGE_INSTRUCTION).name("异常提示改写")
|
||||||
|
.appModule(AppModule.CHAT).description("通过大模型将异常信息改写为更友好和引导性的提示用语")
|
||||||
|
.enable(false).build());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void process(ParseContext parseContext) {
|
||||||
|
String errMsg = parseContext.getResponse().getErrorMsg();
|
||||||
|
ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_ERROR_MESSAGE);
|
||||||
|
if (StringUtils.isBlank(errMsg) || Objects.isNull(chatApp) || !chatApp.isEnable()) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
Map<String, Object> variables = new HashMap<>();
|
||||||
|
variables.put("user_question", parseContext.getRequest().getQueryText());
|
||||||
|
variables.put("system_message", errMsg);
|
||||||
|
|
||||||
|
StringBuilder exampleStr = new StringBuilder();
|
||||||
|
parseContext.getResponse().getUsedExemplars().forEach(e -> exampleStr.append(
|
||||||
|
String.format("<Question:{%s},Schema:{%s}> ", e.getQuestion(), e.getDbSchema())));
|
||||||
|
parseContext.getAgent().getExamples()
|
||||||
|
.forEach(e -> exampleStr.append(String.format("<Question:{%s}> ", e)));
|
||||||
|
variables.put("examples", exampleStr);
|
||||||
|
|
||||||
|
Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variables);
|
||||||
|
ChatLanguageModel chatLanguageModel =
|
||||||
|
ModelProvider.getChatModel(ModelConfigHelper.getChatModelConfig(chatApp));
|
||||||
|
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||||
|
String rewrittenMsg = response.content().text();
|
||||||
|
parseContext.getResponse().setErrorMsg(rewrittenMsg);
|
||||||
|
keyPipelineLog.info("ErrorMessageProcessor modelReq:\n{} \nmodelResp:\n{}", prompt.text(),
|
||||||
|
rewrittenMsg);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,54 @@
|
|||||||
|
package com.tencent.supersonic.chat.server.processor.parse;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||||
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||||
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.Set;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
public class TextInfoProcessor implements ParseResultProcessor {
|
||||||
|
@Override
|
||||||
|
public void process(ParseContext parseContext) {
|
||||||
|
parseContext.getResponse().getSelectedParses().forEach(p -> {
|
||||||
|
if (!PluginQueryManager.isPluginQuery(p.getQueryMode())) {
|
||||||
|
formatNL2SQLParseInfo(p);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void formatNL2SQLParseInfo(SemanticParseInfo parseInfo) {
|
||||||
|
StringBuilder textBuilder = new StringBuilder();
|
||||||
|
textBuilder.append("**数据集:** ").append(parseInfo.getDataSet().getName()).append(" ");
|
||||||
|
Optional<SchemaElement> metric = parseInfo.getMetrics().stream().findFirst();
|
||||||
|
metric.ifPresent(schemaElement -> textBuilder.append("**指标:** ")
|
||||||
|
.append(schemaElement.getName()).append(" "));
|
||||||
|
List<String> dimensionNames = parseInfo.getDimensions().stream().map(SchemaElement::getName)
|
||||||
|
.filter(Objects::nonNull).collect(Collectors.toList());
|
||||||
|
if (!CollectionUtils.isEmpty(dimensionNames)) {
|
||||||
|
textBuilder.append("**维度:** ").append(String.join(",", dimensionNames));
|
||||||
|
}
|
||||||
|
textBuilder.append("\n\n**筛选条件:** \n");
|
||||||
|
if (parseInfo.getDateInfo() != null) {
|
||||||
|
textBuilder.append("**数据时间:** ").append(parseInfo.getDateInfo().getStartDate())
|
||||||
|
.append("~").append(parseInfo.getDateInfo().getEndDate()).append(" ");
|
||||||
|
}
|
||||||
|
if (!CollectionUtils.isEmpty(parseInfo.getDimensionFilters())
|
||||||
|
|| CollectionUtils.isEmpty(parseInfo.getMetricFilters())) {
|
||||||
|
Set<QueryFilter> queryFilters = parseInfo.getDimensionFilters();
|
||||||
|
queryFilters.addAll(parseInfo.getMetricFilters());
|
||||||
|
for (QueryFilter queryFilter : queryFilters) {
|
||||||
|
textBuilder.append("**").append(queryFilter.getName()).append("**").append(" ")
|
||||||
|
.append(queryFilter.getOperator().getValue()).append(" ")
|
||||||
|
.append(queryFilter.getValue()).append(" ");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parseInfo.setTextInfo(textBuilder.toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -12,13 +12,13 @@ import lombok.ToString;
|
|||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@NoArgsConstructor
|
@NoArgsConstructor
|
||||||
public class SchemaElementMatch {
|
public class SchemaElementMatch {
|
||||||
SchemaElement element;
|
private SchemaElement element;
|
||||||
double offset;
|
private double offset;
|
||||||
double similarity;
|
private double similarity;
|
||||||
String detectWord;
|
private String detectWord;
|
||||||
String word;
|
private String word;
|
||||||
Long frequency;
|
private Long frequency;
|
||||||
boolean isInherited;
|
private boolean isInherited;
|
||||||
|
|
||||||
public boolean isFullMatched() {
|
public boolean isFullMatched() {
|
||||||
return 1.0 == similarity;
|
return 1.0 == similarity;
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.api.pojo;
|
|||||||
|
|
||||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
|
import lombok.Getter;
|
||||||
import org.apache.commons.collections4.CollectionUtils;
|
import org.apache.commons.collections4.CollectionUtils;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@@ -9,9 +10,10 @@ import java.util.List;
|
|||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
|
|
||||||
|
@Getter
|
||||||
public class SchemaMapInfo {
|
public class SchemaMapInfo {
|
||||||
|
|
||||||
private Map<Long, List<SchemaElementMatch>> dataSetElementMatches = new HashMap<>();
|
private final Map<Long, List<SchemaElementMatch>> dataSetElementMatches = new HashMap<>();
|
||||||
|
|
||||||
public Set<Long> getMatchedDataSetInfos() {
|
public Set<Long> getMatchedDataSetInfos() {
|
||||||
return dataSetElementMatches.keySet();
|
return dataSetElementMatches.keySet();
|
||||||
@@ -21,10 +23,6 @@ public class SchemaMapInfo {
|
|||||||
return dataSetElementMatches.getOrDefault(dataSet, Lists.newArrayList());
|
return dataSetElementMatches.getOrDefault(dataSet, Lists.newArrayList());
|
||||||
}
|
}
|
||||||
|
|
||||||
public Map<Long, List<SchemaElementMatch>> getDataSetElementMatches() {
|
|
||||||
return dataSetElementMatches;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setMatchedElements(Long dataSet, List<SchemaElementMatch> elementMatches) {
|
public void setMatchedElements(Long dataSet, List<SchemaElementMatch> elementMatches) {
|
||||||
dataSetElementMatches.put(dataSet, elementMatches);
|
dataSetElementMatches.put(dataSet, elementMatches);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,6 +31,15 @@ public class QueryNLReq extends SemanticQueryReq {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toCustomizedString() {
|
public String toCustomizedString() {
|
||||||
return "";
|
StringBuilder stringBuilder = new StringBuilder("{");
|
||||||
|
stringBuilder.append("\"queryText\":").append(dataSetId);
|
||||||
|
stringBuilder.append("\"dataSetId\":").append(dataSetId);
|
||||||
|
stringBuilder.append("\"modelIds\":").append(modelIds);
|
||||||
|
stringBuilder.append(",\"params\":").append(params);
|
||||||
|
stringBuilder.append(",\"cacheInfo\":").append(cacheInfo);
|
||||||
|
stringBuilder.append(",\"mapMode\":").append(mapModeEnum);
|
||||||
|
stringBuilder.append(",\"dataType\":").append(queryDataType);
|
||||||
|
stringBuilder.append('}');
|
||||||
|
return stringBuilder.toString();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ import java.util.Objects;
|
|||||||
public class QuerySqlReq extends SemanticQueryReq {
|
public class QuerySqlReq extends SemanticQueryReq {
|
||||||
|
|
||||||
private String sql;
|
private String sql;
|
||||||
|
|
||||||
private Integer limit = 1000;
|
private Integer limit = 1000;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -7,10 +7,10 @@ import lombok.Data;
|
|||||||
public class MapResp {
|
public class MapResp {
|
||||||
|
|
||||||
private final String queryText;
|
private final String queryText;
|
||||||
|
private final SchemaMapInfo mapInfo;
|
||||||
|
|
||||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
public MapResp(String queryText, SchemaMapInfo schemaMapInfo) {
|
||||||
|
|
||||||
public MapResp(String queryText) {
|
|
||||||
this.queryText = queryText;
|
this.queryText = queryText;
|
||||||
|
this.mapInfo = schemaMapInfo;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package com.tencent.supersonic.headless.api.pojo.response;
|
package com.tencent.supersonic.headless.api.pojo.response;
|
||||||
|
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@@ -16,6 +17,7 @@ public class ParseResp {
|
|||||||
private String errorMsg;
|
private String errorMsg;
|
||||||
private List<SemanticParseInfo> selectedParses = Lists.newArrayList();
|
private List<SemanticParseInfo> selectedParses = Lists.newArrayList();
|
||||||
private ParseTimeCostResp parseTimeCost = new ParseTimeCostResp();
|
private ParseTimeCostResp parseTimeCost = new ParseTimeCostResp();
|
||||||
|
private List<Text2SQLExemplar> usedExemplars;
|
||||||
|
|
||||||
public enum ParseState {
|
public enum ParseState {
|
||||||
COMPLETED, PENDING, FAILED
|
COMPLETED, PENDING, FAILED
|
||||||
|
|||||||
@@ -114,8 +114,7 @@ public abstract class BaseMapper implements SchemaMapper {
|
|||||||
return element.getAlias();
|
return element.getAlias();
|
||||||
}
|
}
|
||||||
|
|
||||||
public <T> List<T> getMatches(ChatQueryContext chatQueryContext,
|
public <T> List<T> getMatches(ChatQueryContext chatQueryContext, MatchStrategy matchStrategy) {
|
||||||
BaseMatchStrategy matchStrategy) {
|
|
||||||
String queryText = chatQueryContext.getRequest().getQueryText();
|
String queryText = chatQueryContext.getRequest().getQueryText();
|
||||||
List<S2Term> terms =
|
List<S2Term> terms =
|
||||||
HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds());
|
HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds());
|
||||||
|
|||||||
@@ -64,11 +64,9 @@ public class S2ChatLayerService implements ChatLayerService {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public MapResp map(QueryNLReq queryNLReq) {
|
public MapResp map(QueryNLReq queryNLReq) {
|
||||||
MapResp mapResp = new MapResp(queryNLReq.getQueryText());
|
|
||||||
ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq);
|
ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq);
|
||||||
ComponentFactory.getSchemaMappers().forEach(mapper -> mapper.map(queryCtx));
|
ComponentFactory.getSchemaMappers().forEach(mapper -> mapper.map(queryCtx));
|
||||||
mapResp.setMapInfo(queryCtx.getMapInfo());
|
return new MapResp(queryNLReq.getQueryText(), queryCtx.getMapInfo());
|
||||||
return mapResp;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|||||||
@@ -67,7 +67,9 @@ com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer=\
|
|||||||
|
|
||||||
com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor=\
|
com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor=\
|
||||||
com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor,\
|
com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor,\
|
||||||
com.tencent.supersonic.chat.server.processor.parse.TimeCostProcessor
|
com.tencent.supersonic.chat.server.processor.parse.TimeCostProcessor,\
|
||||||
|
com.tencent.supersonic.chat.server.processor.parse.ErrorMessageProcessor,\
|
||||||
|
com.tencent.supersonic.chat.server.processor.parse.TextInfoProcessor
|
||||||
|
|
||||||
com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\
|
com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\
|
||||||
com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\
|
com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\
|
||||||
|
|||||||
Reference in New Issue
Block a user