[fix][launcher]Enable multi-turn conversation in S2VisitsDemo.

This commit is contained in:
jerryjzhang
2024-09-29 14:14:59 +08:00
parent 299fd8413a
commit bfdf9004ea
10 changed files with 47 additions and 31 deletions

View File

@@ -81,7 +81,7 @@ public class Agent extends RecordInfo {
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.PLUGIN));
}
public boolean containsLLMParserTool() {
public boolean containsLLMTool() {
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM));
}

View File

@@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.server.service.ChatContextService;
import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.common.util.ContextUtils;
@@ -87,10 +88,9 @@ public class NL2SQLParser implements ChatQueryParser {
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
ChatContext chatCtx = chatContextService.getOrCreateContext(parseContext.getChatId());
ChatLanguageModel chatLanguageModel =
ModelProvider.getChatModel(parseContext.getAgent().getModelConfig());
processMultiTurn(chatLanguageModel, parseContext);
if (parseContext.enbaleLLM()) {
processMultiTurn(parseContext);
}
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx);
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
@@ -99,13 +99,15 @@ public class NL2SQLParser implements ChatQueryParser {
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
} else {
parseResp.setErrorMsg(
rewriteErrorMessage(
chatLanguageModel,
parseContext.getQueryText(),
text2SqlParseResp.getErrorMsg(),
queryNLReq.getDynamicExemplars(),
parseContext.getAgent().getExamples()));
if (parseContext.enbaleLLM()) {
parseResp.setErrorMsg(
rewriteErrorMessage(
parseContext.getQueryText(),
text2SqlParseResp.getErrorMsg(),
queryNLReq.getDynamicExemplars(),
parseContext.getAgent().getExamples(),
parseContext.getAgent().getModelConfig()));
}
}
parseResp.setState(text2SqlParseResp.getState());
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
@@ -178,7 +180,7 @@ public class NL2SQLParser implements ChatQueryParser {
parseInfo.setTextInfo(textBuilder.toString());
}
private void processMultiTurn(ChatLanguageModel chatLanguageModel, ParseContext parseContext) {
private void processMultiTurn(ParseContext parseContext) {
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
MultiTurnConfig agentMultiTurnConfig = parseContext.getAgent().getMultiTurnConfig();
Boolean globalMultiTurnConfig =
@@ -192,6 +194,9 @@ public class NL2SQLParser implements ChatQueryParser {
return;
}
ChatLanguageModel chatLanguageModel =
ModelProvider.getChatModel(parseContext.getAgent().getModelConfig());
// derive mapping result of current question and parsing result of last question.
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
@@ -235,11 +240,11 @@ public class NL2SQLParser implements ChatQueryParser {
}
private String rewriteErrorMessage(
ChatLanguageModel chatLanguageModel,
String userQuestion,
String errMsg,
List<Text2SQLExemplar> similarExemplars,
List<String> agentExamples) {
List<String> agentExamples,
ChatModelConfig modelConfig) {
Map<String, Object> variables = new HashMap<>();
variables.put("user_question", userQuestion);
variables.put("system_message", errMsg);
@@ -256,6 +261,7 @@ public class NL2SQLParser implements ChatQueryParser {
Prompt prompt = PromptTemplate.from(REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(variables);
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text());
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig);
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
String rewrittenMsg = response.content().text();

View File

@@ -22,4 +22,11 @@ public class ParseContext {
}
return agent.containsNL2SQLTool();
}
public boolean enbaleLLM() {
if (agent == null) {
return true;
}
return agent.containsLLMTool();
}
}

View File

@@ -98,7 +98,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
}
private synchronized void doExecuteAgentExamples(Agent agent) {
if (!agent.containsLLMParserTool()
if (!agent.containsLLMTool()
|| !LLMConnHelper.testConnection(agent.getModelConfig())
|| CollectionUtils.isEmpty(agent.getExamples())) {
return;

View File

@@ -24,7 +24,7 @@ public class QueryReqConverter {
return queryNLReq;
}
boolean hasLLMTool = agent.containsLLMParserTool();
boolean hasLLMTool = agent.containsLLMTool();
boolean hasRuleTool = agent.containsRuleTool();
boolean hasLLMConfig = Objects.nonNull(agent.getModelConfig());