[fix][launcher]Enable multi-turn conversation in S2VisitsDemo.

This commit is contained in:
jerryjzhang
2024-09-29 14:14:59 +08:00
parent 299fd8413a
commit bfdf9004ea
10 changed files with 47 additions and 31 deletions

View File

@@ -81,7 +81,7 @@ public class Agent extends RecordInfo {
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.PLUGIN)); return !CollectionUtils.isEmpty(getParserTools(AgentToolType.PLUGIN));
} }
public boolean containsLLMParserTool() { public boolean containsLLMTool() {
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM)); return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM));
} }

View File

@@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.server.service.ChatContextService;
import com.tencent.supersonic.chat.server.service.ChatManageService; import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl; import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
@@ -87,10 +88,9 @@ public class NL2SQLParser implements ChatQueryParser {
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class); ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
ChatContext chatCtx = chatContextService.getOrCreateContext(parseContext.getChatId()); ChatContext chatCtx = chatContextService.getOrCreateContext(parseContext.getChatId());
ChatLanguageModel chatLanguageModel = if (parseContext.enbaleLLM()) {
ModelProvider.getChatModel(parseContext.getAgent().getModelConfig()); processMultiTurn(parseContext);
}
processMultiTurn(chatLanguageModel, parseContext);
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx); QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx);
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq); addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
@@ -99,13 +99,15 @@ public class NL2SQLParser implements ChatQueryParser {
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) { if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses()); parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
} else { } else {
if (parseContext.enbaleLLM()) {
parseResp.setErrorMsg( parseResp.setErrorMsg(
rewriteErrorMessage( rewriteErrorMessage(
chatLanguageModel,
parseContext.getQueryText(), parseContext.getQueryText(),
text2SqlParseResp.getErrorMsg(), text2SqlParseResp.getErrorMsg(),
queryNLReq.getDynamicExemplars(), queryNLReq.getDynamicExemplars(),
parseContext.getAgent().getExamples())); parseContext.getAgent().getExamples(),
parseContext.getAgent().getModelConfig()));
}
} }
parseResp.setState(text2SqlParseResp.getState()); parseResp.setState(text2SqlParseResp.getState());
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime()); parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
@@ -178,7 +180,7 @@ public class NL2SQLParser implements ChatQueryParser {
parseInfo.setTextInfo(textBuilder.toString()); parseInfo.setTextInfo(textBuilder.toString());
} }
private void processMultiTurn(ChatLanguageModel chatLanguageModel, ParseContext parseContext) { private void processMultiTurn(ParseContext parseContext) {
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
MultiTurnConfig agentMultiTurnConfig = parseContext.getAgent().getMultiTurnConfig(); MultiTurnConfig agentMultiTurnConfig = parseContext.getAgent().getMultiTurnConfig();
Boolean globalMultiTurnConfig = Boolean globalMultiTurnConfig =
@@ -192,6 +194,9 @@ public class NL2SQLParser implements ChatQueryParser {
return; return;
} }
ChatLanguageModel chatLanguageModel =
ModelProvider.getChatModel(parseContext.getAgent().getModelConfig());
// derive mapping result of current question and parsing result of last question. // derive mapping result of current question and parsing result of last question.
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class); ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext); QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
@@ -235,11 +240,11 @@ public class NL2SQLParser implements ChatQueryParser {
} }
private String rewriteErrorMessage( private String rewriteErrorMessage(
ChatLanguageModel chatLanguageModel,
String userQuestion, String userQuestion,
String errMsg, String errMsg,
List<Text2SQLExemplar> similarExemplars, List<Text2SQLExemplar> similarExemplars,
List<String> agentExamples) { List<String> agentExamples,
ChatModelConfig modelConfig) {
Map<String, Object> variables = new HashMap<>(); Map<String, Object> variables = new HashMap<>();
variables.put("user_question", userQuestion); variables.put("user_question", userQuestion);
variables.put("system_message", errMsg); variables.put("system_message", errMsg);
@@ -256,6 +261,7 @@ public class NL2SQLParser implements ChatQueryParser {
Prompt prompt = PromptTemplate.from(REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(variables); Prompt prompt = PromptTemplate.from(REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(variables);
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text()); keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text());
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig);
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage()); Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
String rewrittenMsg = response.content().text(); String rewrittenMsg = response.content().text();

View File

@@ -22,4 +22,11 @@ public class ParseContext {
} }
return agent.containsNL2SQLTool(); return agent.containsNL2SQLTool();
} }
public boolean enbaleLLM() {
if (agent == null) {
return true;
}
return agent.containsLLMTool();
}
} }

View File

@@ -98,7 +98,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
} }
private synchronized void doExecuteAgentExamples(Agent agent) { private synchronized void doExecuteAgentExamples(Agent agent) {
if (!agent.containsLLMParserTool() if (!agent.containsLLMTool()
|| !LLMConnHelper.testConnection(agent.getModelConfig()) || !LLMConnHelper.testConnection(agent.getModelConfig())
|| CollectionUtils.isEmpty(agent.getExamples())) { || CollectionUtils.isEmpty(agent.getExamples())) {
return; return;

View File

@@ -24,7 +24,7 @@ public class QueryReqConverter {
return queryNLReq; return queryNLReq;
} }
boolean hasLLMTool = agent.containsLLMParserTool(); boolean hasLLMTool = agent.containsLLMTool();
boolean hasRuleTool = agent.containsRuleTool(); boolean hasRuleTool = agent.containsRuleTool();
boolean hasLLMConfig = Objects.nonNull(agent.getModelConfig()); boolean hasLLMConfig = Objects.nonNull(agent.getModelConfig());

View File

@@ -38,9 +38,8 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
+ "\n4.DO NOT calculate date range using functions." + "\n4.DO NOT calculate date range using functions."
+ "\n5.DO NOT calculate date range using DATE_SUB." + "\n5.DO NOT calculate date range using DATE_SUB."
+ "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed." + "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
+ "\n7.ALWAYS USE `with` statement to handle secondary calculation scenario."
+ "\n#Exemplars:\n{{exemplar}}" + "\n#Exemplars:\n{{exemplar}}"
+ "#Question:\nQuestion:{{question}},Schema:{{schema}},SideInfo:{{information}}"; + "\n#Question:\nQuestion:{{question}},Schema:{{schema}},SideInfo:{{information}}";
@Data @Data
static class SemanticSql { static class SemanticSql {

View File

@@ -1,11 +1,11 @@
package com.tencent.supersonic.headless.server.rest; package com.tencent.supersonic.headless.server.rest;
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
import com.tencent.supersonic.headless.api.pojo.request.TermReq; import com.tencent.supersonic.headless.api.pojo.request.TermReq;
import com.tencent.supersonic.headless.api.pojo.response.TermResp; import com.tencent.supersonic.headless.api.pojo.response.TermResp;
import com.tencent.supersonic.headless.server.service.TermService; import com.tencent.supersonic.headless.server.service.TermService;
@@ -38,7 +38,8 @@ public class TermController {
} }
@GetMapping @GetMapping
public List<TermResp> getTerms(@RequestParam("domainId") Long domainId, public List<TermResp> getTerms(
@RequestParam("domainId") Long domainId,
@RequestParam(name = "queryKey", required = false) String queryKey) { @RequestParam(name = "queryKey", required = false) String queryKey) {
return termService.getTerms(domainId, queryKey); return termService.getTerms(domainId, queryKey);
} }
@@ -55,5 +56,4 @@ public class TermController {
termService.deleteBatch(metaBatchReq); termService.deleteBatch(metaBatchReq);
return true; return true;
} }
} }

View File

@@ -56,13 +56,15 @@ public class TermServiceImpl extends ServiceImpl<TermMapper, TermDO> implements
QueryWrapper<TermDO> queryWrapper = new QueryWrapper<>(); QueryWrapper<TermDO> queryWrapper = new QueryWrapper<>();
queryWrapper.lambda().eq(TermDO::getDomainId, domainId); queryWrapper.lambda().eq(TermDO::getDomainId, domainId);
if (StringUtils.isNotBlank(queryKey)) { if (StringUtils.isNotBlank(queryKey)) {
queryWrapper.lambda().and(i -> queryWrapper
.lambda()
.and(
i ->
i.like(TermDO::getName, queryKey) i.like(TermDO::getName, queryKey)
.or() .or()
.like(TermDO::getDescription, queryKey) .like(TermDO::getDescription, queryKey)
.or() .or()
.like(TermDO::getAlias, queryKey) .like(TermDO::getAlias, queryKey));
);
} }
List<TermDO> termDOS = list(queryWrapper); List<TermDO> termDOS = list(queryWrapper);
return termDOS.stream().map(this::convert).collect(Collectors.toList()); return termDOS.stream().map(this::convert).collect(Collectors.toList());

View File

@@ -171,7 +171,7 @@ public class S2VisitsDemo extends S2BaseDemo {
agentConfig.getTools().add(llmParserTool); agentConfig.getTools().add(llmParserTool);
} }
agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
MultiTurnConfig multiTurnConfig = new MultiTurnConfig(false); MultiTurnConfig multiTurnConfig = new MultiTurnConfig(true);
agent.setMultiTurnConfig(multiTurnConfig); agent.setMultiTurnConfig(multiTurnConfig);
Agent agentCreated = agentService.createAgent(agent, User.getFakeUser()); Agent agentCreated = agentService.createAgent(agent, User.getFakeUser());
return agentCreated.getId(); return agentCreated.getId();

View File

@@ -13,6 +13,7 @@ import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryState; import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.util.DataUtils; import com.tencent.supersonic.util.DataUtils;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.test.context.TestPropertySource;
import java.time.LocalDate; import java.time.LocalDate;
import java.util.Set; import java.util.Set;
@@ -20,6 +21,7 @@ import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
@TestPropertySource(properties = {"s2.demo.enableLLM = false"})
public class BaseTest extends BaseApplication { public class BaseTest extends BaseApplication {
protected final int unit = 7; protected final int unit = 7;