diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/ChatQueryExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/ChatQueryExecutor.java index 0a2115820..2ed37998f 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/ChatQueryExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/ChatQueryExecutor.java @@ -5,5 +5,7 @@ import com.tencent.supersonic.chat.server.pojo.ExecuteContext; public interface ChatQueryExecutor { + boolean accept(ExecuteContext executeContext); + QueryResult execute(ExecuteContext executeContext); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java index adcab30a3..67332e7f2 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java @@ -37,11 +37,12 @@ public class PlainTextExecutor implements ChatQueryExecutor { } @Override - public QueryResult execute(ExecuteContext executeContext) { - if (!"PLAIN_TEXT".equals(executeContext.getParseInfo().getQueryMode())) { - return null; - } + public boolean accept(ExecuteContext executeContext) { + return "PLAIN_TEXT".equals(executeContext.getParseInfo().getQueryMode()); + } + @Override + public QueryResult execute(ExecuteContext executeContext) { AgentService agentService = ContextUtils.getBean(AgentService.class); Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId()); ChatApp chatApp = chatAgent.getChatAppConfig().get(APP_KEY); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PluginExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PluginExecutor.java index 90278f54e..7ba80c779 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PluginExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PluginExecutor.java @@ -8,6 +8,11 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; public class PluginExecutor implements ChatQueryExecutor { + @Override + public boolean accept(ExecuteContext executeContext) { + return PluginQueryManager.isPluginQuery(executeContext.getParseInfo().getQueryMode()); + } + @Override public QueryResult execute(ExecuteContext executeContext) { SemanticParseInfo parseInfo = executeContext.getParseInfo(); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java index c378e1be4..0871dbae8 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java @@ -25,6 +25,11 @@ import java.util.Objects; public class SqlExecutor implements ChatQueryExecutor { + @Override + public boolean accept(ExecuteContext executeContext) { + return true; + } + @SneakyThrows @Override public QueryResult execute(ExecuteContext executeContext) { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/ChatQueryParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/ChatQueryParser.java index 33410a3f7..9f9b9e46a 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/ChatQueryParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/ChatQueryParser.java @@ -4,5 +4,7 @@ import com.tencent.supersonic.chat.server.pojo.ParseContext; public interface ChatQueryParser { + boolean accept(ParseContext parseContext); + void parse(ParseContext parseContext); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java index c48f05da1..026c33714 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java @@ -14,12 +14,12 @@ public class NL2PluginParser implements ChatQueryParser { private final List pluginRecognizers = ComponentFactory.getPluginRecognizers(); + public boolean accept(ParseContext parseContext) { + return parseContext.getAgent().containsPluginTool(); + } + @Override public void parse(ParseContext parseContext) { - if (!parseContext.getAgent().containsPluginTool()) { - return; - } - pluginRecognizers.forEach(pluginRecognizer -> { pluginRecognizer.recognize(parseContext); log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(), diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 7022a54cf..8e6199ea0 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -73,12 +73,12 @@ public class NL2SQLParser implements ChatQueryParser { .build()); } + public boolean accept(ParseContext parseContext) { + return parseContext.enableNL2SQL(); + } + @Override public void parse(ParseContext parseContext) { - if (!parseContext.enableNL2SQL()) { - return; - } - // first go with rule-based parsers unless the user has already selected one parse. if (Objects.isNull(parseContext.getRequest().getSelectedParse())) { QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/PlainTextParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/PlainTextParser.java index ff13e30ce..bc304d069 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/PlainTextParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/PlainTextParser.java @@ -6,12 +6,12 @@ import com.tencent.supersonic.headless.api.pojo.response.ParseResp; public class PlainTextParser implements ChatQueryParser { + public boolean accept(ParseContext parseContext) { + return !parseContext.getAgent().containsAnyTool(); + } + @Override public void parse(ParseContext parseContext) { - if (parseContext.getAgent().containsAnyTool()) { - return; - } - SemanticParseInfo parseInfo = new SemanticParseInfo(); parseInfo.setQueryMode("PLAIN_TEXT"); parseInfo.setId(1); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java index b1b37897e..484e22305 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java @@ -95,7 +95,11 @@ public class ChatQueryServiceImpl implements ChatQueryService { } ParseContext parseContext = buildParseContext(chatParseReq, new ChatParseResp(queryId)); - chatQueryParsers.forEach(p -> p.parse(parseContext)); + for (ChatQueryParser parser : chatQueryParsers) { + if (parser.accept(parseContext)) { + parser.parse(parseContext); + } + } for (ParseResultProcessor processor : parseResultProcessors) { if (processor.accept(parseContext)) { @@ -116,9 +120,11 @@ public class ChatQueryServiceImpl implements ChatQueryService { QueryResult queryResult = new QueryResult(); ExecuteContext executeContext = buildExecuteContext(chatExecuteReq); for (ChatQueryExecutor chatQueryExecutor : chatQueryExecutors) { - queryResult = chatQueryExecutor.execute(executeContext); - if (queryResult != null) { - break; + if (chatQueryExecutor.accept(executeContext)) { + queryResult = chatQueryExecutor.execute(executeContext); + if (queryResult != null) { + break; + } } } diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/ChatModelServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/ChatModelServiceImpl.java index 7cea5dae3..dd289a3db 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/impl/ChatModelServiceImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/ChatModelServiceImpl.java @@ -25,12 +25,12 @@ public class ChatModelServiceImpl extends ServiceImpl getChatModels(User user) { return list().stream().map(this::convert).filter(chatModel -> { - if (chatModel.isPublic() || user.isSuperAdmin() - || chatModel.getCreatedBy().equals(user.getName()) - || chatModel.getViewers().contains(user.getName())) { - return true; - } - return false; + if (chatModel.isPublic() || user.isSuperAdmin() + || chatModel.getCreatedBy().equals(user.getName()) + || chatModel.getViewers().contains(user.getName())) { + return true; + } + return false; }).collect(Collectors.toList()); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ChatModelController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ChatModelController.java index 3c08d5356..ae377600a 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ChatModelController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ChatModelController.java @@ -53,7 +53,7 @@ public class ChatModelController { @RequestMapping("/getModelList") public List getModelList(HttpServletRequest httpServletRequest, - HttpServletResponse httpServletResponse) { + HttpServletResponse httpServletResponse) { User user = UserHolder.findUser(httpServletRequest, httpServletResponse); return chatModelService.getChatModels(user); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java index 4397f2280..f166b81db 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java @@ -96,8 +96,7 @@ public class DatabaseServiceImpl extends ServiceImpl