diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java index 3b98400dc..101df3632 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java @@ -3,9 +3,15 @@ package com.tencent.supersonic.chat.api.pojo.request; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; +import lombok.AllArgsConstructor; +import lombok.Builder; import lombok.Data; +import lombok.NoArgsConstructor; @Data +@NoArgsConstructor +@AllArgsConstructor +@Builder public class ChatParseReq { private String queryText; private Integer chatId; @@ -15,4 +21,5 @@ public class ChatParseReq { private QueryFilters queryFilters; private boolean saveAnswer = true; private SchemaMapInfo mapInfo = new SchemaMapInfo(); + private boolean disableLLM = false; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java index 233a255a3..f3716be85 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ParseContext.java @@ -15,17 +15,18 @@ public class ParseContext { private QueryFilters queryFilters; private boolean saveAnswer = true; private SchemaMapInfo mapInfo; + private boolean disableLLM = false; public boolean enableNL2SQL() { if (agent == null) { - return true; + return false; } return agent.containsNL2SQLTool(); } public boolean enbaleLLM() { - if (agent == null) { - return true; + if (agent == null || disableLLM) { + return false; } return agent.containsLLMTool(); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatQueryService.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatQueryService.java index 07b29d467..f357160f0 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatQueryService.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/ChatQueryService.java @@ -19,7 +19,7 @@ public interface ChatQueryService { QueryResult performExecution(ChatExecuteReq chatExecuteReq) throws Exception; - QueryResult parseAndExecute(int chatId, int agentId, String queryText); + QueryResult parseAndExecute(ChatParseReq chatParseReq); Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java index 2ea94f443..646622409 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java @@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.server.service.impl; import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.pojo.request.ChatMemoryFilter; +import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO; @@ -115,7 +116,8 @@ public class AgentServiceImpl extends ServiceImpl implem continue; } try { - chatQueryService.parseAndExecute(-1, agent.getId(), example); + chatQueryService.parseAndExecute(ChatParseReq.builder().chatId(-1) + .agentId(agent.getId()).queryText(example).build()); } catch (Exception e) { log.warn("agent:{} example execute failed:{}", agent.getName(), example); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java index 74488deea..111ab227f 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java @@ -145,25 +145,21 @@ public class ChatQueryServiceImpl implements ChatQueryService { } @Override - public QueryResult parseAndExecute(int chatId, int agentId, String queryText) { - ChatParseReq chatParseReq = new ChatParseReq(); - chatParseReq.setQueryText(queryText); - chatParseReq.setChatId(chatId); - chatParseReq.setAgentId(agentId); - chatParseReq.setUser(User.getFakeUser()); + public QueryResult parseAndExecute(ChatParseReq chatParseReq) { ParseResp parseResp = performParsing(chatParseReq); if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) { log.debug("chatId:{}, agentId:{}, queryText:{}, parseResp.getSelectedParses() is empty", - chatId, agentId, queryText); + chatParseReq.getChatId(), chatParseReq.getAgentId(), + chatParseReq.getQueryText()); return null; } ChatExecuteReq executeReq = new ChatExecuteReq(); executeReq.setQueryId(parseResp.getQueryId()); executeReq.setParseId(parseResp.getSelectedParses().get(0).getId()); - executeReq.setQueryText(queryText); - executeReq.setChatId(chatId); + executeReq.setQueryText(chatParseReq.getQueryText()); + executeReq.setChatId(chatParseReq.getChatId()); executeReq.setUser(User.getFakeUser()); - executeReq.setAgentId(agentId); + executeReq.setAgentId(chatParseReq.getAgentId()); executeReq.setSaveAnswer(true); return performExecution(executeReq); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java index f6d067515..7f01c0472 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java @@ -28,7 +28,9 @@ public class QueryReqConverter { boolean hasRuleTool = agent.containsRuleTool(); boolean hasLLMConfig = Objects.nonNull(agent.getModelConfig()); - if (hasLLMTool && hasLLMConfig) { + if (parseContext.isDisableLLM()) { + queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE); + } else if (hasLLMTool && hasLLMConfig) { queryNLReq.setText2SQLType(Text2SQLType.ONLY_LLM); } else if (hasLLMTool && hasRuleTool) { queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM); @@ -37,6 +39,7 @@ public class QueryReqConverter { } else if (hasRuleTool) { queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE); } + queryNLReq.setDataSetIds(agent.getDataSetIds()); if (Objects.nonNull(queryNLReq.getMapInfo()) && MapUtils.isNotEmpty(queryNLReq.getMapInfo().getDataSetElementMatches())) { diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java index 9994b3cc8..f8f78fc82 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java @@ -5,6 +5,7 @@ import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup; import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule; +import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.AgentConfig; import com.tencent.supersonic.chat.server.agent.AgentToolType; @@ -136,12 +137,16 @@ public class S2VisitsDemo extends S2BaseDemo { public void addSampleChats(Integer agentId) { Long chatId = chatManageService.addChat(user, "样例对话1", agentId); + submitText(chatId.intValue(), agentId, "超音数 访问次数"); + submitText(chatId.intValue(), agentId, "按部门统计"); + submitText(chatId.intValue(), agentId, "查询近30天"); + submitText(chatId.intValue(), agentId, "alice 停留时长"); + submitText(chatId.intValue(), agentId, "访问次数最高的部门"); + } - chatQueryService.parseAndExecute(chatId.intValue(), agentId, "超音数 访问次数"); - chatQueryService.parseAndExecute(chatId.intValue(), agentId, "按部门统计"); - chatQueryService.parseAndExecute(chatId.intValue(), agentId, "查询近30天"); - chatQueryService.parseAndExecute(chatId.intValue(), agentId, "alice 停留时长"); - chatQueryService.parseAndExecute(chatId.intValue(), agentId, "访问次数最高的部门"); + private void submitText(int chatId, int agentId, String queryText) { + chatQueryService.parseAndExecute(ChatParseReq.builder().chatId(chatId).agentId(agentId) + .queryText(queryText).user(User.getFakeUser()).disableLLM(true).build()); } private Integer addAgent(long dataSetId) {