From 7f91993084a174c36b34c14cd8b8c6aa67630ccb Mon Sep 17 00:00:00 2001 From: williamhliu <137068196+williamhliu@users.noreply.github.com> Date: Tue, 25 Jun 2024 21:12:18 +0800 Subject: [PATCH 1/2] (improvement)(chat-sdk) add agentId to execute api and add PLAIN_TEXT query mode (#1223) --- .../chat-sdk/src/components/ChatItem/SqlItem.tsx | 2 +- .../chat-sdk/src/components/ChatItem/index.tsx | 2 +- .../chat-sdk/src/components/ChatMsg/index.tsx | 12 +++++++----- webapp/packages/chat-sdk/src/service/index.ts | 8 +++++++- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/webapp/packages/chat-sdk/src/components/ChatItem/SqlItem.tsx b/webapp/packages/chat-sdk/src/components/ChatItem/SqlItem.tsx index 8aa911055..7c898a610 100644 --- a/webapp/packages/chat-sdk/src/components/ChatItem/SqlItem.tsx +++ b/webapp/packages/chat-sdk/src/components/ChatItem/SqlItem.tsx @@ -96,7 +96,7 @@ const SqlItem: React.FC = ({ setSqlType(sqlType === 's2SQL' ? '' : 's2SQL'); }} > - {queryMode === 'LLM_S2SQL' ? 'LLM' : 'Rule'}解析S2SQL + {queryMode === 'LLM_S2SQL' || queryMode === 'PLAIN_TEXT' ? 'LLM' : 'Rule'}解析S2SQL )} {sqlInfo.correctS2SQL && ( diff --git a/webapp/packages/chat-sdk/src/components/ChatItem/index.tsx b/webapp/packages/chat-sdk/src/components/ChatItem/index.tsx index f3b19b938..ee1a76c0e 100644 --- a/webapp/packages/chat-sdk/src/components/ChatItem/index.tsx +++ b/webapp/packages/chat-sdk/src/components/ChatItem/index.tsx @@ -123,7 +123,7 @@ const ChatItem: React.FC = ({ setExecuteLoading(true); } try { - const res: any = await chatExecute(msg, conversationId!, parseInfoValue); + const res: any = await chatExecute(msg, conversationId!, parseInfoValue, agentId); const valid = updateData(res); onMsgDataLoaded?.( { diff --git a/webapp/packages/chat-sdk/src/components/ChatMsg/index.tsx b/webapp/packages/chat-sdk/src/components/ChatMsg/index.tsx index 6ba6ce51b..8faaac0d7 100644 --- a/webapp/packages/chat-sdk/src/components/ChatMsg/index.tsx +++ b/webapp/packages/chat-sdk/src/components/ChatMsg/index.tsx @@ -86,11 +86,13 @@ const ChatMsg: React.FC = ({ const isMetricCard = (queryMode.includes('METRIC') || isDslMetricCard) && singleData; const isText = - columns.length === 1 && - columns[0].showType === 'CATEGORY' && - ((!queryMode.includes('METRIC') && !queryMode.includes('ENTITY')) || - queryMode === 'METRIC_INTERPRET') && - singleData; + queryMode === 'PLAIN_TEXT' || + (columns.length === 1 && + columns[0].showType === 'CATEGORY' && + ((!queryMode.includes('METRIC') && !queryMode.includes('ENTITY')) || + queryMode === 'METRIC_INTERPRET') && + singleData); + if (isText) { return MsgContentTypeEnum.TEXT; } diff --git a/webapp/packages/chat-sdk/src/service/index.ts b/webapp/packages/chat-sdk/src/service/index.ts index 60cf22cba..a00f84abd 100644 --- a/webapp/packages/chat-sdk/src/service/index.ts +++ b/webapp/packages/chat-sdk/src/service/index.ts @@ -61,9 +61,15 @@ export function chatParse( }); } -export function chatExecute(queryText: string, chatId: number, parseInfo: ChatContextType) { +export function chatExecute( + queryText: string, + chatId: number, + parseInfo: ChatContextType, + agentId?: number +) { return axios.post(`${prefix}/chat/query/execute`, { queryText, + agentId, chatId: chatId || DEFAULT_CHAT_ID, queryId: parseInfo.queryId, parseId: parseInfo.id, From c64aa624566ab6334561bc97e2a2ba209ceb0d65 Mon Sep 17 00:00:00 2001 From: Jun Zhang Date: Tue, 25 Jun 2024 21:16:02 +0800 Subject: [PATCH 2/2] (feature)(chat)Introduce new plain_text mode to allow users to talk to LLM directly. (#1224) --- .../chat/api/pojo/request/ChatExecuteReq.java | 1 + .../supersonic/chat/server/agent/Agent.java | 17 ++++++++ .../server/executor/PlainTextExecutor.java | 42 +++++++++++++++++++ .../chat/server/executor/SqlExecutor.java | 14 +++---- .../chat/server/parser/MultiTurnParser.java | 4 ++ .../chat/server/parser/NL2PluginParser.java | 4 ++ .../chat/server/parser/NL2SQLParser.java | 7 +--- .../chat/server/parser/PlainTextParser.java | 18 ++++++++ .../chat/server/pojo/ChatExecuteContext.java | 1 + .../processor/parse/EntityInfoProcessor.java | 3 +- .../server/service/impl/ChatServiceImpl.java | 10 +++-- .../headless/api/pojo/SemanticParseInfo.java | 2 +- .../headless/chat/query/QueryManager.java | 4 +- .../main/resources/META-INF/spring.factories | 6 ++- .../com/tencent/supersonic/chat/BaseTest.java | 4 +- 15 files changed, 112 insertions(+), 25 deletions(-) create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/PlainTextParser.java diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatExecuteReq.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatExecuteReq.java index 0080cef8e..58884567d 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatExecuteReq.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatExecuteReq.java @@ -12,6 +12,7 @@ import lombok.NoArgsConstructor; @AllArgsConstructor public class ChatExecuteReq { private User user; + private Integer agentId; private Long queryId; private Integer chatId; private int parseId; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java index 7ebb4c3aa..4101de7bc 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java @@ -71,6 +71,10 @@ public class Agent extends RecordInfo { .collect(Collectors.toList()); } + public boolean containsPluginTool() { + return !CollectionUtils.isEmpty(getParserTools(AgentToolType.PLUGIN)); + } + public boolean containsLLMParserTool() { return !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_LLM)); } @@ -84,6 +88,19 @@ public class Agent extends RecordInfo { || !CollectionUtils.isEmpty(getParserTools(AgentToolType.NL2SQL_RULE)); } + public boolean containsAnyTool() { + Map map = JSONObject.parseObject(agentConfig, Map.class); + if (CollectionUtils.isEmpty(map)) { + return false; + } + List toolList = (List) map.get("tools"); + if (CollectionUtils.isEmpty(toolList)) { + return false; + } + + return true; + } + public Set getDataSetIds() { Set dataSetIds = getDataSetIds(null); if (containsAllModel(dataSetIds)) { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java new file mode 100644 index 000000000..ac39df600 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java @@ -0,0 +1,42 @@ +package com.tencent.supersonic.chat.server.executor; + +import com.tencent.supersonic.chat.server.agent.Agent; +import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext; +import com.tencent.supersonic.chat.server.service.AgentService; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.S2ChatModelProvider; +import com.tencent.supersonic.headless.api.pojo.response.QueryResult; +import com.tencent.supersonic.headless.api.pojo.response.QueryState; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.chat.ChatLanguageModel; +import dev.langchain4j.model.input.Prompt; +import dev.langchain4j.model.input.PromptTemplate; +import dev.langchain4j.model.output.Response; + +import java.util.Collections; + +public class PlainTextExecutor implements ChatExecutor { + + @Override + public QueryResult execute(ChatExecuteContext chatExecuteContext) { + if (!"PLAIN_TEXT".equals(chatExecuteContext.getParseInfo().getQueryMode())) { + return null; + } + + Prompt prompt = PromptTemplate.from(chatExecuteContext.getQueryText()) + .apply(Collections.EMPTY_MAP); + + AgentService agentService = ContextUtils.getBean(AgentService.class); + Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId()); + + ChatLanguageModel chatLanguageModel = S2ChatModelProvider.provide(chatAgent.getLlmConfig()); + Response response = chatLanguageModel.generate(prompt.toUserMessage()); + + QueryResult result = new QueryResult(); + result.setQueryState(QueryState.SUCCESS); + result.setQueryMode(chatExecuteContext.getParseInfo().getQueryMode()); + result.setTextResult(response.content().text()); + + return result; + } +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java index d4db2cffc..db173ed91 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.chat.server.executor; -import com.tencent.supersonic.chat.server.plugin.PluginQueryManager; import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext; import com.tencent.supersonic.chat.server.util.ResultFormatter; import com.tencent.supersonic.common.util.ContextUtils; @@ -15,16 +14,15 @@ public class SqlExecutor implements ChatExecutor { @SneakyThrows @Override public QueryResult execute(ChatExecuteContext chatExecuteContext) { - SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo(); - if (PluginQueryManager.isPluginQuery(parseInfo.getQueryMode())) { - return null; - } ExecuteQueryReq executeQueryReq = buildExecuteReq(chatExecuteContext); ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class); QueryResult queryResult = chatQueryService.performExecution(executeQueryReq); - String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(), - queryResult.getQueryResults()); - queryResult.setTextResult(textResult); + if (queryResult != null) { + String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(), + queryResult.getQueryResults()); + queryResult.setTextResult(textResult); + } + return queryResult; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java index 392760289..e8e1e3783 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java @@ -52,6 +52,10 @@ public class MultiTurnParser implements ChatParser { @Override public void parse(ChatParseContext chatParseContext, ParseResp parseResp) { + if (!chatParseContext.getAgent().containsAnyTool()) { + return; + } + ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig(); Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE)); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java index 19608d2d8..305fdf01a 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java @@ -15,6 +15,10 @@ public class NL2PluginParser implements ChatParser { @Override public void parse(ChatParseContext chatParseContext, ParseResp parseResp) { + if (!chatParseContext.getAgent().containsPluginTool()) { + return; + } + pluginRecognizers.forEach(pluginRecognizer -> { pluginRecognizer.recognize(chatParseContext, parseResp); log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(), diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 90d2c4583..c0cd758d8 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -23,14 +23,11 @@ public class NL2SQLParser implements ChatParser { @Override public void parse(ChatParseContext chatParseContext, ParseResp parseResp) { - if (!chatParseContext.enableNL2SQL()) { + if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) { return; } - if (checkSkip(parseResp)) { - return; - } - QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); + QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class); ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq); if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/PlainTextParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/PlainTextParser.java new file mode 100644 index 000000000..f46fd4f04 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/PlainTextParser.java @@ -0,0 +1,18 @@ +package com.tencent.supersonic.chat.server.parser; + +import com.tencent.supersonic.chat.server.pojo.ChatParseContext; +import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.headless.api.pojo.response.ParseResp; + +public class PlainTextParser implements ChatParser { + @Override + public void parse(ChatParseContext chatParseContext, ParseResp parseResp) { + if (chatParseContext.getAgent().containsAnyTool()) { + return; + } + + SemanticParseInfo parseInfo = new SemanticParseInfo(); + parseInfo.setQueryMode("PLAIN_TEXT"); + parseResp.getSelectedParses().add(parseInfo); + } +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ChatExecuteContext.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ChatExecuteContext.java index cc8b97f9b..4457d475e 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ChatExecuteContext.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ChatExecuteContext.java @@ -7,6 +7,7 @@ import lombok.Data; @Data public class ChatExecuteContext { private User user; + private Integer agentId; private Long queryId; private Integer chatId; private int parseId; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/EntityInfoProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/EntityInfoProcessor.java index 3986e916e..9028c5d9a 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/EntityInfoProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/EntityInfoProcessor.java @@ -26,9 +26,10 @@ public class EntityInfoProcessor implements ParseResultProcessor { } selectedParses.forEach(parseInfo -> { String queryMode = parseInfo.getQueryMode(); - if (QueryManager.containsRuleQuery(queryMode)) { + if (QueryManager.containsRuleQuery(queryMode) || "PLAIN".equals(queryMode)) { return; } + //1. set entity info SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class); DataSetSchema dataSetSchema = semanticService.getDataSetSchema(parseInfo.getDataSetId()); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java index 60edd1f1b..f7a393274 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java @@ -90,10 +90,14 @@ public class ChatServiceImpl implements ChatService { break; } } - for (ExecuteResultProcessor processor : executeResultProcessors) { - processor.process(chatExecuteContext, queryResult); + + if (queryResult != null) { + for (ExecuteResultProcessor processor : executeResultProcessors) { + processor.process(chatExecuteContext, queryResult); + } + saveQueryResult(chatExecuteReq, queryResult); } - saveQueryResult(chatExecuteReq, queryResult); + return queryResult; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java index b9e7b9574..d1ff206aa 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java @@ -21,7 +21,7 @@ import java.util.TreeSet; public class SemanticParseInfo { private Integer id; - private String queryMode; + private String queryMode = "PLAIN_TEXT"; private SchemaElement dataSet; private Set metrics = new TreeSet<>(new SchemaNameLengthComparator()); private Set dimensions = new LinkedHashSet(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/QueryManager.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/QueryManager.java index 91de9690b..c1d835923 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/QueryManager.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/QueryManager.java @@ -43,12 +43,12 @@ public class QueryManager { private static SemanticQuery getSemanticQuery(String queryMode, SemanticQuery semanticQuery) { if (Objects.isNull(semanticQuery)) { - throw new RuntimeException("no supported queryMode :" + queryMode); + return null; } try { return semanticQuery.getClass().getDeclaredConstructor().newInstance(); } catch (Exception e) { - throw new RuntimeException("no supported queryMode :" + queryMode); + return null; } } diff --git a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories index ad7b6d458..26bfc7a22 100644 --- a/launchers/standalone/src/main/resources/META-INF/spring.factories +++ b/launchers/standalone/src/main/resources/META-INF/spring.factories @@ -56,11 +56,13 @@ com.tencent.supersonic.headless.server.processor.ResultProcessor=\ com.tencent.supersonic.chat.server.parser.ChatParser=\ com.tencent.supersonic.chat.server.parser.NL2PluginParser, \ com.tencent.supersonic.chat.server.parser.MultiTurnParser,\ - com.tencent.supersonic.chat.server.parser.NL2SQLParser + com.tencent.supersonic.chat.server.parser.NL2SQLParser,\ + com.tencent.supersonic.chat.server.parser.PlainTextParser com.tencent.supersonic.chat.server.executor.ChatExecutor=\ com.tencent.supersonic.chat.server.executor.PluginExecutor, \ - com.tencent.supersonic.chat.server.executor.SqlExecutor + com.tencent.supersonic.chat.server.executor.SqlExecutor,\ + com.tencent.supersonic.chat.server.executor.PlainTextExecutor com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer=\ com.tencent.supersonic.chat.server.plugin.recognize.embedding.EmbeddingRecallRecognizer diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java index 12c3517a7..db36e0d48 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java @@ -5,7 +5,6 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq; import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.ChatService; -import com.tencent.supersonic.chat.server.service.ConfigService; import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; @@ -31,8 +30,6 @@ public class BaseTest extends BaseApplication { @Autowired protected ChatService chatService; @Autowired - protected ConfigService configService; - @Autowired protected AgentService agentService; protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId) throws Exception { @@ -61,6 +58,7 @@ public class BaseTest extends BaseApplication { .queryText(parseResp.getQueryText()) .user(DataUtils.getUser()) .parseId(parseInfo.getId()) + .agentId(agentId) .queryId(parseResp.getQueryId()) .saveAnswer(false) .build();