mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 20:51:48 +00:00
[fix][launcher]Enable multi-turn conversation in S2VisitsDemo.
This commit is contained in:
@@ -81,7 +81,7 @@ public class Agent extends RecordInfo {
|
|||||||
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.PLUGIN));
|
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.PLUGIN));
|
||||||
}
|
}
|
||||||
|
|
||||||
public boolean containsLLMParserTool() {
|
public boolean containsLLMTool() {
|
||||||
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM));
|
return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.server.service.ChatContextService;
|
|||||||
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
@@ -87,10 +88,9 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
||||||
ChatContext chatCtx = chatContextService.getOrCreateContext(parseContext.getChatId());
|
ChatContext chatCtx = chatContextService.getOrCreateContext(parseContext.getChatId());
|
||||||
|
|
||||||
ChatLanguageModel chatLanguageModel =
|
if (parseContext.enbaleLLM()) {
|
||||||
ModelProvider.getChatModel(parseContext.getAgent().getModelConfig());
|
processMultiTurn(parseContext);
|
||||||
|
}
|
||||||
processMultiTurn(chatLanguageModel, parseContext);
|
|
||||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx);
|
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx);
|
||||||
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
|
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
|
||||||
|
|
||||||
@@ -99,13 +99,15 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
|
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
|
||||||
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
|
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
|
||||||
} else {
|
} else {
|
||||||
parseResp.setErrorMsg(
|
if (parseContext.enbaleLLM()) {
|
||||||
rewriteErrorMessage(
|
parseResp.setErrorMsg(
|
||||||
chatLanguageModel,
|
rewriteErrorMessage(
|
||||||
parseContext.getQueryText(),
|
parseContext.getQueryText(),
|
||||||
text2SqlParseResp.getErrorMsg(),
|
text2SqlParseResp.getErrorMsg(),
|
||||||
queryNLReq.getDynamicExemplars(),
|
queryNLReq.getDynamicExemplars(),
|
||||||
parseContext.getAgent().getExamples()));
|
parseContext.getAgent().getExamples(),
|
||||||
|
parseContext.getAgent().getModelConfig()));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
parseResp.setState(text2SqlParseResp.getState());
|
parseResp.setState(text2SqlParseResp.getState());
|
||||||
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
|
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
|
||||||
@@ -178,7 +180,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
parseInfo.setTextInfo(textBuilder.toString());
|
parseInfo.setTextInfo(textBuilder.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
private void processMultiTurn(ChatLanguageModel chatLanguageModel, ParseContext parseContext) {
|
private void processMultiTurn(ParseContext parseContext) {
|
||||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||||
MultiTurnConfig agentMultiTurnConfig = parseContext.getAgent().getMultiTurnConfig();
|
MultiTurnConfig agentMultiTurnConfig = parseContext.getAgent().getMultiTurnConfig();
|
||||||
Boolean globalMultiTurnConfig =
|
Boolean globalMultiTurnConfig =
|
||||||
@@ -192,6 +194,9 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ChatLanguageModel chatLanguageModel =
|
||||||
|
ModelProvider.getChatModel(parseContext.getAgent().getModelConfig());
|
||||||
|
|
||||||
// derive mapping result of current question and parsing result of last question.
|
// derive mapping result of current question and parsing result of last question.
|
||||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
||||||
@@ -235,11 +240,11 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private String rewriteErrorMessage(
|
private String rewriteErrorMessage(
|
||||||
ChatLanguageModel chatLanguageModel,
|
|
||||||
String userQuestion,
|
String userQuestion,
|
||||||
String errMsg,
|
String errMsg,
|
||||||
List<Text2SQLExemplar> similarExemplars,
|
List<Text2SQLExemplar> similarExemplars,
|
||||||
List<String> agentExamples) {
|
List<String> agentExamples,
|
||||||
|
ChatModelConfig modelConfig) {
|
||||||
Map<String, Object> variables = new HashMap<>();
|
Map<String, Object> variables = new HashMap<>();
|
||||||
variables.put("user_question", userQuestion);
|
variables.put("user_question", userQuestion);
|
||||||
variables.put("system_message", errMsg);
|
variables.put("system_message", errMsg);
|
||||||
@@ -256,6 +261,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
|
|
||||||
Prompt prompt = PromptTemplate.from(REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(variables);
|
Prompt prompt = PromptTemplate.from(REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(variables);
|
||||||
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text());
|
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text());
|
||||||
|
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig);
|
||||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||||
|
|
||||||
String rewrittenMsg = response.content().text();
|
String rewrittenMsg = response.content().text();
|
||||||
|
|||||||
@@ -22,4 +22,11 @@ public class ParseContext {
|
|||||||
}
|
}
|
||||||
return agent.containsNL2SQLTool();
|
return agent.containsNL2SQLTool();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public boolean enbaleLLM() {
|
||||||
|
if (agent == null) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return agent.containsLLMTool();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
|||||||
}
|
}
|
||||||
|
|
||||||
private synchronized void doExecuteAgentExamples(Agent agent) {
|
private synchronized void doExecuteAgentExamples(Agent agent) {
|
||||||
if (!agent.containsLLMParserTool()
|
if (!agent.containsLLMTool()
|
||||||
|| !LLMConnHelper.testConnection(agent.getModelConfig())
|
|| !LLMConnHelper.testConnection(agent.getModelConfig())
|
||||||
|| CollectionUtils.isEmpty(agent.getExamples())) {
|
|| CollectionUtils.isEmpty(agent.getExamples())) {
|
||||||
return;
|
return;
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ public class QueryReqConverter {
|
|||||||
return queryNLReq;
|
return queryNLReq;
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean hasLLMTool = agent.containsLLMParserTool();
|
boolean hasLLMTool = agent.containsLLMTool();
|
||||||
boolean hasRuleTool = agent.containsRuleTool();
|
boolean hasRuleTool = agent.containsRuleTool();
|
||||||
boolean hasLLMConfig = Objects.nonNull(agent.getModelConfig());
|
boolean hasLLMConfig = Objects.nonNull(agent.getModelConfig());
|
||||||
|
|
||||||
|
|||||||
@@ -38,9 +38,8 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
|||||||
+ "\n4.DO NOT calculate date range using functions."
|
+ "\n4.DO NOT calculate date range using functions."
|
||||||
+ "\n5.DO NOT calculate date range using DATE_SUB."
|
+ "\n5.DO NOT calculate date range using DATE_SUB."
|
||||||
+ "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
|
+ "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
|
||||||
+ "\n7.ALWAYS USE `with` statement to handle secondary calculation scenario."
|
|
||||||
+ "\n#Exemplars:\n{{exemplar}}"
|
+ "\n#Exemplars:\n{{exemplar}}"
|
||||||
+ "#Question:\nQuestion:{{question}},Schema:{{schema}},SideInfo:{{information}}";
|
+ "\n#Question:\nQuestion:{{question}},Schema:{{schema}},SideInfo:{{information}}";
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
static class SemanticSql {
|
static class SemanticSql {
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
package com.tencent.supersonic.headless.server.rest;
|
package com.tencent.supersonic.headless.server.rest;
|
||||||
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
|
|
||||||
import javax.servlet.http.HttpServletRequest;
|
import javax.servlet.http.HttpServletRequest;
|
||||||
import javax.servlet.http.HttpServletResponse;
|
import javax.servlet.http.HttpServletResponse;
|
||||||
|
|
||||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.TermReq;
|
import com.tencent.supersonic.headless.api.pojo.request.TermReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.TermResp;
|
import com.tencent.supersonic.headless.api.pojo.response.TermResp;
|
||||||
import com.tencent.supersonic.headless.server.service.TermService;
|
import com.tencent.supersonic.headless.server.service.TermService;
|
||||||
@@ -38,8 +38,9 @@ public class TermController {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@GetMapping
|
@GetMapping
|
||||||
public List<TermResp> getTerms(@RequestParam("domainId") Long domainId,
|
public List<TermResp> getTerms(
|
||||||
@RequestParam(name = "queryKey", required = false) String queryKey) {
|
@RequestParam("domainId") Long domainId,
|
||||||
|
@RequestParam(name = "queryKey", required = false) String queryKey) {
|
||||||
return termService.getTerms(domainId, queryKey);
|
return termService.getTerms(domainId, queryKey);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -55,5 +56,4 @@ public class TermController {
|
|||||||
termService.deleteBatch(metaBatchReq);
|
termService.deleteBatch(metaBatchReq);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -56,13 +56,15 @@ public class TermServiceImpl extends ServiceImpl<TermMapper, TermDO> implements
|
|||||||
QueryWrapper<TermDO> queryWrapper = new QueryWrapper<>();
|
QueryWrapper<TermDO> queryWrapper = new QueryWrapper<>();
|
||||||
queryWrapper.lambda().eq(TermDO::getDomainId, domainId);
|
queryWrapper.lambda().eq(TermDO::getDomainId, domainId);
|
||||||
if (StringUtils.isNotBlank(queryKey)) {
|
if (StringUtils.isNotBlank(queryKey)) {
|
||||||
queryWrapper.lambda().and(i ->
|
queryWrapper
|
||||||
i.like(TermDO::getName, queryKey)
|
.lambda()
|
||||||
.or()
|
.and(
|
||||||
.like(TermDO::getDescription, queryKey)
|
i ->
|
||||||
.or()
|
i.like(TermDO::getName, queryKey)
|
||||||
.like(TermDO::getAlias, queryKey)
|
.or()
|
||||||
);
|
.like(TermDO::getDescription, queryKey)
|
||||||
|
.or()
|
||||||
|
.like(TermDO::getAlias, queryKey));
|
||||||
}
|
}
|
||||||
List<TermDO> termDOS = list(queryWrapper);
|
List<TermDO> termDOS = list(queryWrapper);
|
||||||
return termDOS.stream().map(this::convert).collect(Collectors.toList());
|
return termDOS.stream().map(this::convert).collect(Collectors.toList());
|
||||||
|
|||||||
@@ -171,7 +171,7 @@ public class S2VisitsDemo extends S2BaseDemo {
|
|||||||
agentConfig.getTools().add(llmParserTool);
|
agentConfig.getTools().add(llmParserTool);
|
||||||
}
|
}
|
||||||
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
|
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
|
||||||
MultiTurnConfig multiTurnConfig = new MultiTurnConfig(false);
|
MultiTurnConfig multiTurnConfig = new MultiTurnConfig(true);
|
||||||
agent.setMultiTurnConfig(multiTurnConfig);
|
agent.setMultiTurnConfig(multiTurnConfig);
|
||||||
Agent agentCreated = agentService.createAgent(agent, User.getFakeUser());
|
Agent agentCreated = agentService.createAgent(agent, User.getFakeUser());
|
||||||
return agentCreated.getId();
|
return agentCreated.getId();
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||||
import com.tencent.supersonic.util.DataUtils;
|
import com.tencent.supersonic.util.DataUtils;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.test.context.TestPropertySource;
|
||||||
|
|
||||||
import java.time.LocalDate;
|
import java.time.LocalDate;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
@@ -20,6 +21,7 @@ import java.util.stream.Collectors;
|
|||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
@TestPropertySource(properties = {"s2.demo.enableLLM = false"})
|
||||||
public class BaseTest extends BaseApplication {
|
public class BaseTest extends BaseApplication {
|
||||||
|
|
||||||
protected final int unit = 7;
|
protected final int unit = 7;
|
||||||
|
|||||||
Reference in New Issue
Block a user