|
|
|
|
@@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.server.service.ChatContextService;
|
|
|
|
|
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
|
|
|
|
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
|
|
|
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
|
|
|
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
|
|
|
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
|
|
|
|
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
|
|
|
|
import com.tencent.supersonic.common.util.ContextUtils;
|
|
|
|
|
@@ -87,10 +88,9 @@ public class NL2SQLParser implements ChatQueryParser {
|
|
|
|
|
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
|
|
|
|
ChatContext chatCtx = chatContextService.getOrCreateContext(parseContext.getChatId());
|
|
|
|
|
|
|
|
|
|
ChatLanguageModel chatLanguageModel =
|
|
|
|
|
ModelProvider.getChatModel(parseContext.getAgent().getModelConfig());
|
|
|
|
|
|
|
|
|
|
processMultiTurn(chatLanguageModel, parseContext);
|
|
|
|
|
if (parseContext.enbaleLLM()) {
|
|
|
|
|
processMultiTurn(parseContext);
|
|
|
|
|
}
|
|
|
|
|
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx);
|
|
|
|
|
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
|
|
|
|
|
|
|
|
|
|
@@ -99,13 +99,15 @@ public class NL2SQLParser implements ChatQueryParser {
|
|
|
|
|
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
|
|
|
|
|
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
|
|
|
|
|
} else {
|
|
|
|
|
parseResp.setErrorMsg(
|
|
|
|
|
rewriteErrorMessage(
|
|
|
|
|
chatLanguageModel,
|
|
|
|
|
parseContext.getQueryText(),
|
|
|
|
|
text2SqlParseResp.getErrorMsg(),
|
|
|
|
|
queryNLReq.getDynamicExemplars(),
|
|
|
|
|
parseContext.getAgent().getExamples()));
|
|
|
|
|
if (parseContext.enbaleLLM()) {
|
|
|
|
|
parseResp.setErrorMsg(
|
|
|
|
|
rewriteErrorMessage(
|
|
|
|
|
parseContext.getQueryText(),
|
|
|
|
|
text2SqlParseResp.getErrorMsg(),
|
|
|
|
|
queryNLReq.getDynamicExemplars(),
|
|
|
|
|
parseContext.getAgent().getExamples(),
|
|
|
|
|
parseContext.getAgent().getModelConfig()));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
parseResp.setState(text2SqlParseResp.getState());
|
|
|
|
|
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
|
|
|
|
|
@@ -178,7 +180,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
|
|
|
|
parseInfo.setTextInfo(textBuilder.toString());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private void processMultiTurn(ChatLanguageModel chatLanguageModel, ParseContext parseContext) {
|
|
|
|
|
private void processMultiTurn(ParseContext parseContext) {
|
|
|
|
|
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
|
|
|
|
MultiTurnConfig agentMultiTurnConfig = parseContext.getAgent().getMultiTurnConfig();
|
|
|
|
|
Boolean globalMultiTurnConfig =
|
|
|
|
|
@@ -192,6 +194,9 @@ public class NL2SQLParser implements ChatQueryParser {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ChatLanguageModel chatLanguageModel =
|
|
|
|
|
ModelProvider.getChatModel(parseContext.getAgent().getModelConfig());
|
|
|
|
|
|
|
|
|
|
// derive mapping result of current question and parsing result of last question.
|
|
|
|
|
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
|
|
|
|
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
|
|
|
|
@@ -235,11 +240,11 @@ public class NL2SQLParser implements ChatQueryParser {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private String rewriteErrorMessage(
|
|
|
|
|
ChatLanguageModel chatLanguageModel,
|
|
|
|
|
String userQuestion,
|
|
|
|
|
String errMsg,
|
|
|
|
|
List<Text2SQLExemplar> similarExemplars,
|
|
|
|
|
List<String> agentExamples) {
|
|
|
|
|
List<String> agentExamples,
|
|
|
|
|
ChatModelConfig modelConfig) {
|
|
|
|
|
Map<String, Object> variables = new HashMap<>();
|
|
|
|
|
variables.put("user_question", userQuestion);
|
|
|
|
|
variables.put("system_message", errMsg);
|
|
|
|
|
@@ -256,6 +261,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
|
|
|
|
|
|
|
|
|
Prompt prompt = PromptTemplate.from(REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(variables);
|
|
|
|
|
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text());
|
|
|
|
|
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig);
|
|
|
|
|
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
|
|
|
|
|
|
|
|
|
String rewrittenMsg = response.content().text();
|
|
|
|
|
|