From d4cc53acae6f65f7d6cd71032467ea3fbf30f9c8 Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Tue, 25 Jun 2024 21:14:19 +0800 Subject: [PATCH] (feature)(chat)Introduce new plain_text mode to allow users to talk to LLM directly. --- .../chat/api/pojo/request/ChatExecuteReq.java | 1 + .../supersonic/chat/server/agent/Agent.java | 17 ++++++++ .../server/executor/PlainTextExecutor.java | 42 +++++++++++++++++++ .../chat/server/executor/SqlExecutor.java | 14 +++---- .../chat/server/parser/MultiTurnParser.java | 4 ++ .../chat/server/parser/NL2PluginParser.java | 4 ++ .../chat/server/parser/NL2SQLParser.java | 7 +--- .../chat/server/parser/PlainTextParser.java | 18 ++++++++ .../chat/server/pojo/ChatExecuteContext.java | 1 + .../processor/parse/EntityInfoProcessor.java | 3 +- .../server/service/impl/ChatServiceImpl.java | 10 +++-- .../headless/api/pojo/SemanticParseInfo.java | 2 +- .../headless/chat/query/QueryManager.java | 4 +- .../main/resources/META-INF/spring.factories | 6 ++- .../com/tencent/supersonic/chat/BaseTest.java | 4 +- 15 files changed, 112 insertions(+), 25 deletions(-) create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/PlainTextParser.java diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatExecuteReq.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatExecuteReq.java index 0080cef8e..58884567d 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatExecuteReq.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatExecuteReq.java @@ -12,6 +12,7 @@ import lombok.NoArgsConstructor; @AllArgsConstructor public class ChatExecuteReq { private User user; + private Integer agentId; private Long queryId; private Integer chatId; private int parseId; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java index 7ebb4c3aa..4101de7bc 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java @@ -71,6 +71,10 @@ public class Agent extends RecordInfo { .collect(Collectors.toList()); } + public boolean containsPluginTool() { + return !CollectionUtils.isEmpty(getParserTools(AgentToolType.PLUGIN)); + } + public boolean containsLLMParserTool() { return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM)); } @@ -84,6 +88,19 @@ public class Agent extends RecordInfo { || !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE)); } + public boolean containsAnyTool() { + Map map = JSONObject.parseObject(agentConfig, Map.class); + if (CollectionUtils.isEmpty(map)) { + return false; + } + List toolList = (List) map.get("tools"); + if (CollectionUtils.isEmpty(toolList)) { + return false; + } + + return true; + } + public Set getDataSetIds() { Set dataSetIds = getDataSetIds(null); if (containsAllModel(dataSetIds)) { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java new file mode 100644 index 000000000..ac39df600 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java @@ -0,0 +1,42 @@ +package com.tencent.supersonic.chat.server.executor; + +import com.tencent.supersonic.chat.server.agent.Agent; +import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext; +import com.tencent.supersonic.chat.server.service.AgentService; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.S2ChatModelProvider; +import com.tencent.supersonic.headless.api.pojo.response.QueryResult; +import com.tencent.supersonic.headless.api.pojo.response.QueryState; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.input.Prompt; +import dev.langchain4j.model.input.PromptTemplate; +import dev.langchain4j.model.output.Response; + +import java.util.Collections; + +public class PlainTextExecutor implements ChatExecutor { + + @Override + public QueryResult execute(ChatExecuteContext chatExecuteContext) { + if (!"PLAIN_TEXT".equals(chatExecuteContext.getParseInfo().getQueryMode())) { + return null; + } + + Prompt prompt = PromptTemplate.from(chatExecuteContext.getQueryText()) + .apply(Collections.EMPTY_MAP); + + AgentService agentService = ContextUtils.getBean(AgentService.class); + Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId()); + + ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(chatAgent.getLlmConfig()); + Response response = chatLanguageModel.generate(prompt.toUserMessage()); + + QueryResult result = new QueryResult(); + result.setQueryState(QueryState.SUCCESS); + result.setQueryMode(chatExecuteContext.getParseInfo().getQueryMode()); + result.setTextResult(response.content().text()); + + return result; + } +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java index d4db2cffc..db173ed91 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.chat.server.executor; -import com.tencent.supersonic.chat.server.plugin.PluginQueryManager; import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext; import com.tencent.supersonic.chat.server.util.ResultFormatter; import com.tencent.supersonic.common.util.ContextUtils; @@ -15,16 +14,15 @@ public class SqlExecutor implements ChatExecutor { @SneakyThrows @Override public QueryResult execute(ChatExecuteContext chatExecuteContext) { - SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo(); - if (PluginQueryManager.isPluginQuery(parseInfo.getQueryMode())) { - return null; - } ExecuteQueryReq executeQueryReq = buildExecuteReq(chatExecuteContext); ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class); QueryResult queryResult = chatQueryService.performExecution(executeQueryReq); - String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(), - queryResult.getQueryResults()); - queryResult.setTextResult(textResult); + if (queryResult != null) { + String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(), + queryResult.getQueryResults()); + queryResult.setTextResult(textResult); + } + return queryResult; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java index 392760289..e8e1e3783 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java @@ -52,6 +52,10 @@ public class MultiTurnParser implements ChatParser { @Override public void parse(ChatParseContext chatParseContext, ParseResp parseResp) { + if (!chatParseContext.getAgent().containsAnyTool()) { + return; + } + ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig(); Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE)); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java index 19608d2d8..305fdf01a 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java @@ -15,6 +15,10 @@ public class NL2PluginParser implements ChatParser { @Override public void parse(ChatParseContext chatParseContext, ParseResp parseResp) { + if (!chatParseContext.getAgent().containsPluginTool()) { + return; + } + pluginRecognizers.forEach(pluginRecognizer -> { pluginRecognizer.recognize(chatParseContext, parseResp); log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(), diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 90d2c4583..c0cd758d8 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -23,14 +23,11 @@ public class NL2SQLParser implements ChatParser { @Override public void parse(ChatParseContext chatParseContext, ParseResp parseResp) { - if (!chatParseContext.enableNL2SQL()) { + if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) { return; } - if (checkSkip(parseResp)) { - return; - } - QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); + QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class); ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq); if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/PlainTextParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/PlainTextParser.java new file mode 100644 index 000000000..f46fd4f04 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/PlainTextParser.java @@ -0,0 +1,18 @@ +package com.tencent.supersonic.chat.server.parser; + +import com.tencent.supersonic.chat.server.pojo.ChatParseContext; +import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.headless.api.pojo.response.ParseResp; + +public class PlainTextParser implements ChatParser { + @Override + public void parse(ChatParseContext chatParseContext, ParseResp parseResp) { + if (chatParseContext.getAgent().containsAnyTool()) { + return; + } + + SemanticParseInfo parseInfo = new SemanticParseInfo(); + parseInfo.setQueryMode("PLAIN_TEXT"); + parseResp.getSelectedParses().add(parseInfo); + } +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ChatExecuteContext.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ChatExecuteContext.java index cc8b97f9b..4457d475e 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ChatExecuteContext.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ChatExecuteContext.java @@ -7,6 +7,7 @@ import lombok.Data; @Data public class ChatExecuteContext { private User user; + private Integer agentId; private Long queryId; private Integer chatId; private int parseId; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/EntityInfoProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/EntityInfoProcessor.java index 3986e916e..9028c5d9a 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/EntityInfoProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/EntityInfoProcessor.java @@ -26,9 +26,10 @@ public class EntityInfoProcessor implements ParseResultProcessor { } selectedParses.forEach(parseInfo -> { String queryMode = parseInfo.getQueryMode(); - if (QueryManager.containsRuleQuery(queryMode)) { + if (QueryManager.containsRuleQuery(queryMode) || "PLAIN".equals(queryMode)) { return; } + //1. set entity info SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class); DataSetSchema dataSetSchema = semanticService.getDataSetSchema(parseInfo.getDataSetId()); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java index 60edd1f1b..f7a393274 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java @@ -90,10 +90,14 @@ public class ChatServiceImpl implements ChatService { break; } } - for (ExecuteResultProcessor processor : executeResultProcessors) { - processor.process(chatExecuteContext, queryResult); + + if (queryResult != null) { + for (ExecuteResultProcessor processor : executeResultProcessors) { + processor.process(chatExecuteContext, queryResult); + } + saveQueryResult(chatExecuteReq, queryResult); } - saveQueryResult(chatExecuteReq, queryResult); + return queryResult; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java index b9e7b9574..d1ff206aa 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java @@ -21,7 +21,7 @@ import java.util.TreeSet; public class SemanticParseInfo { private Integer id; - private String queryMode; + private String queryMode = "PLAIN_TEXT"; private SchemaElement dataSet; private Set metrics = new TreeSet<>(new SchemaNameLengthComparator()); private Set dimensions = new LinkedHashSet(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/QueryManager.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/QueryManager.java index 91de9690b..c1d835923 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/QueryManager.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/QueryManager.java @@ -43,12 +43,12 @@ public class QueryManager { private static SemanticQuery getSemanticQuery(String queryMode, SemanticQuery semanticQuery) { if (Objects.isNull(semanticQuery)) { - throw new RuntimeException("no supported queryMode :" + queryMode); + return null; } try { return semanticQuery.getClass().getDeclaredConstructor().newInstance(); } catch (Exception e) { - throw new RuntimeException("no supported queryMode :" + queryMode); + return null; } } diff --git a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories index ad7b6d458..26bfc7a22 100644 --- a/launchers/standalone/src/main/resources/META-INF/spring.factories +++ b/launchers/standalone/src/main/resources/META-INF/spring.factories @@ -56,11 +56,13 @@ com.tencent.supersonic.headless.server.processor.ResultProcessor=\ com.tencent.supersonic.chat.server.parser.ChatParser=\ com.tencent.supersonic.chat.server.parser.NL2PluginParser, \ com.tencent.supersonic.chat.server.parser.MultiTurnParser,\ - com.tencent.supersonic.chat.server.parser.NL2SQLParser + com.tencent.supersonic.chat.server.parser.NL2SQLParser,\ + com.tencent.supersonic.chat.server.parser.PlainTextParser com.tencent.supersonic.chat.server.executor.ChatExecutor=\ com.tencent.supersonic.chat.server.executor.PluginExecutor, \ - com.tencent.supersonic.chat.server.executor.SqlExecutor + com.tencent.supersonic.chat.server.executor.SqlExecutor,\ + com.tencent.supersonic.chat.server.executor.PlainTextExecutor com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer=\ com.tencent.supersonic.chat.server.plugin.recognize.embedding.EmbeddingRecallRecognizer diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java index 12c3517a7..db36e0d48 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java @@ -5,7 +5,6 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq; import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.ChatService; -import com.tencent.supersonic.chat.server.service.ConfigService; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; @@ -31,8 +30,6 @@ public class BaseTest extends BaseApplication { @Autowired protected ChatService chatService; @Autowired - protected ConfigService configService; - @Autowired protected AgentService agentService; protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId) throws Exception { @@ -61,6 +58,7 @@ public class BaseTest extends BaseApplication { .queryText(parseResp.getQueryText()) .user(DataUtils.getUser()) .parseId(parseInfo.getId()) + .agentId(agentId) .queryId(parseResp.getQueryId()) .saveAnswer(false) .build();