diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java index 627318011..87ffae79f 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java @@ -81,7 +81,7 @@ public class Agent extends RecordInfo { return !CollectionUtils.isEmpty(getParserTools(AgentToolType.PLUGIN)); } - public boolean containsLLMParserTool() { + public boolean containsLLMTool() { return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM)); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index f8fb224ff..9bf9b8419 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -9,6 +9,7 @@ import com.tencent.supersonic.chat.server.service.ChatContextService; import com.tencent.supersonic.chat.server.service.ChatManageService; import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.common.config.EmbeddingConfig; +import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl; import com.tencent.supersonic.common.util.ContextUtils; @@ -87,10 +88,9 @@ public class NL2SQLParser implements ChatQueryParser { ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class); ChatContext chatCtx = chatContextService.getOrCreateContext(parseContext.getChatId()); - ChatLanguageModel chatLanguageModel = - ModelProvider.getChatModel(parseContext.getAgent().getModelConfig()); - - processMultiTurn(chatLanguageModel, parseContext); + if (parseContext.enbaleLLM()) { + processMultiTurn(parseContext); + } QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx); addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq); @@ -99,13 +99,15 @@ public class NL2SQLParser implements ChatQueryParser { if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) { parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses()); } else { - parseResp.setErrorMsg( - rewriteErrorMessage( - chatLanguageModel, - parseContext.getQueryText(), - text2SqlParseResp.getErrorMsg(), - queryNLReq.getDynamicExemplars(), - parseContext.getAgent().getExamples())); + if (parseContext.enbaleLLM()) { + parseResp.setErrorMsg( + rewriteErrorMessage( + parseContext.getQueryText(), + text2SqlParseResp.getErrorMsg(), + queryNLReq.getDynamicExemplars(), + parseContext.getAgent().getExamples(), + parseContext.getAgent().getModelConfig())); + } } parseResp.setState(text2SqlParseResp.getState()); parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime()); @@ -178,7 +180,7 @@ public class NL2SQLParser implements ChatQueryParser { parseInfo.setTextInfo(textBuilder.toString()); } - private void processMultiTurn(ChatLanguageModel chatLanguageModel, ParseContext parseContext) { + private void processMultiTurn(ParseContext parseContext) { ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); MultiTurnConfig agentMultiTurnConfig = parseContext.getAgent().getMultiTurnConfig(); Boolean globalMultiTurnConfig = @@ -192,6 +194,9 @@ public class NL2SQLParser implements ChatQueryParser { return; } + ChatLanguageModel chatLanguageModel = + ModelProvider.getChatModel(parseContext.getAgent().getModelConfig()); + // derive mapping result of current question and parsing result of last question. ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class); QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext); @@ -235,11 +240,11 @@ public class NL2SQLParser implements ChatQueryParser { } private String rewriteErrorMessage( - ChatLanguageModel chatLanguageModel, String userQuestion, String errMsg, List similarExemplars, - List agentExamples) { + List agentExamples, + ChatModelConfig modelConfig) { Map variables = new HashMap<>(); variables.put("user_question", userQuestion); variables.put("system_message", errMsg); @@ -256,6 +261,7 @@ public class NL2SQLParser implements ChatQueryParser { Prompt prompt = PromptTemplate.from(REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(variables); keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text()); + ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig); Response response = chatLanguageModel.generate(prompt.toUserMessage()); String rewrittenMsg = response.content().text(); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java index 4573b6ef1..233a255a3 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java @@ -22,4 +22,11 @@ public class ParseContext { } return agent.containsNL2SQLTool(); } + + public boolean enbaleLLM() { + if (agent == null) { + return true; + } + return agent.containsLLMTool(); + } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java index a6ef3e7ec..6719d08c6 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java @@ -98,7 +98,7 @@ public class AgentServiceImpl extends ServiceImpl implem } private synchronized void doExecuteAgentExamples(Agent agent) { - if (!agent.containsLLMParserTool() + if (!agent.containsLLMTool() || !LLMConnHelper.testConnection(agent.getModelConfig()) || CollectionUtils.isEmpty(agent.getExamples())) { return; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java index fbe19a895..f6d067515 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java @@ -24,7 +24,7 @@ public class QueryReqConverter { return queryNLReq; } - boolean hasLLMTool = agent.containsLLMParserTool(); + boolean hasLLMTool = agent.containsLLMTool(); boolean hasRuleTool = agent.containsRuleTool(); boolean hasLLMConfig = Objects.nonNull(agent.getModelConfig()); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java index 625a08805..f33e2b0dc 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java @@ -38,9 +38,8 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { + "\n4.DO NOT calculate date range using functions." + "\n5.DO NOT calculate date range using DATE_SUB." + "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed." - + "\n7.ALWAYS USE `with` statement to handle secondary calculation scenario." + "\n#Exemplars:\n{{exemplar}}" - + "#Question:\nQuestion:{{question}},Schema:{{schema}},SideInfo:{{information}}"; + + "\n#Question:\nQuestion:{{question}},Schema:{{schema}},SideInfo:{{information}}"; @Data static class SemanticSql { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/TermController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/TermController.java index bee51079f..39dfe64a2 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/TermController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/TermController.java @@ -1,11 +1,11 @@ package com.tencent.supersonic.headless.server.rest; -import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; +import com.tencent.supersonic.headless.api.pojo.request.MetaBatchReq; import com.tencent.supersonic.headless.api.pojo.request.TermReq; import com.tencent.supersonic.headless.api.pojo.response.TermResp; import com.tencent.supersonic.headless.server.service.TermService; @@ -38,8 +38,9 @@ public class TermController { } @GetMapping - public List getTerms(@RequestParam("domainId") Long domainId, - @RequestParam(name = "queryKey", required = false) String queryKey) { + public List getTerms( + @RequestParam("domainId") Long domainId, + @RequestParam(name = "queryKey", required = false) String queryKey) { return termService.getTerms(domainId, queryKey); } @@ -55,5 +56,4 @@ public class TermController { termService.deleteBatch(metaBatchReq); return true; } - } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TermServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TermServiceImpl.java index c51d57d36..b116d0f45 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TermServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TermServiceImpl.java @@ -56,13 +56,15 @@ public class TermServiceImpl extends ServiceImpl implements QueryWrapper queryWrapper = new QueryWrapper<>(); queryWrapper.lambda().eq(TermDO::getDomainId, domainId); if (StringUtils.isNotBlank(queryKey)) { - queryWrapper.lambda().and(i -> - i.like(TermDO::getName, queryKey) - .or() - .like(TermDO::getDescription, queryKey) - .or() - .like(TermDO::getAlias, queryKey) - ); + queryWrapper + .lambda() + .and( + i -> + i.like(TermDO::getName, queryKey) + .or() + .like(TermDO::getDescription, queryKey) + .or() + .like(TermDO::getAlias, queryKey)); } List termDOS = list(queryWrapper); return termDOS.stream().map(this::convert).collect(Collectors.toList()); diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java index bf07824fa..23ad37215 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java @@ -171,7 +171,7 @@ public class S2VisitsDemo extends S2BaseDemo { agentConfig.getTools().add(llmParserTool); } agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); - MultiTurnConfig multiTurnConfig = new MultiTurnConfig(false); + MultiTurnConfig multiTurnConfig = new MultiTurnConfig(true); agent.setMultiTurnConfig(multiTurnConfig); Agent agentCreated = agentService.createAgent(agent, User.getFakeUser()); return agentCreated.getId(); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java index 7ba9527f3..6011e83a4 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java @@ -13,6 +13,7 @@ import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.QueryState; import com.tencent.supersonic.util.DataUtils; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.TestPropertySource; import java.time.LocalDate; import java.util.Set; @@ -20,6 +21,7 @@ import java.util.stream.Collectors; import static org.junit.Assert.assertEquals; +@TestPropertySource(properties = {"s2.demo.enableLLM = false"}) public class BaseTest extends BaseApplication { protected final int unit = 7;