From 5519a507ae9a719755d56fb62e921e4fe3bed4ae Mon Sep 17 00:00:00 2001 From: LXW <1264174498@qq.com> Date: Thu, 14 Mar 2024 18:37:09 +0800 Subject: [PATCH] (improvement)(Chat) Integrate chat with plugin recognizer and parse result processor (#820) Co-authored-by: jolunoluo --- .../chat/api/pojo/request/ChatParseReq.java | 2 + .../chat/server/executor/ChatExecutor.java | 10 ++ .../chat/server/executor/PluginExecutor.java | 21 +++++ .../chat/server/executor/SqlExecutor.java | 38 ++++++++ .../chat/server/parser/ChatParser.java | 10 ++ .../chat/server/parser/Text2PluginParser.java | 20 ++++ .../chat/server/parser/Text2SqlParser.java | 22 +++++ .../chat/server/plugin/PluginManager.java | 17 ++-- .../server/plugin/PluginQueryManager.java | 24 +++++ .../plugin/build/PluginSemanticQuery.java | 16 ++-- .../plugin/build/webpage/WebPageQuery.java | 34 ++++--- .../build/webservice/WebServiceQuery.java | 37 +++++--- .../recall/function/FunctionCallConfig.java | 15 --- .../plugin/recall/function/FunctionFiled.java | 13 --- .../function/FunctionPromptGenerator.java | 75 --------------- .../plugin/recall/function/FunctionReq.java | 17 ---- .../plugin/recall/function/FunctionResp.java | 10 -- .../plugin/recall/function/Parameters.java | 18 ---- .../PluginRecognizer.java} | 50 +++++----- .../embedding/EmbeddingRecallRecognizer.java} | 23 +++-- .../embedding/RecallRetrieval.java | 2 +- .../embedding/RecallRetrievalResp.java | 2 +- .../chat/server/pojo/ChatExecuteContext.java | 16 ++++ .../chat/server/pojo/ChatParseContext.java | 17 ++++ .../processor/parse/ParseResultProcessor.java | 4 +- .../parse/QueryRecommendProcessor.java | 15 ++- .../processor/parse/RespBuildProcessor.java | 28 ++++++ .../server/service/impl/ChatServiceImpl.java | 91 ++++++++++--------- .../chat/server/util/ComponentFactory.java | 30 +++++- .../chat/server/util/QueryReqConverter.java | 30 ++++++ .../headless/api/pojo/response/MapResp.java | 13 +++ .../rest/api/ChatQueryApiController.java | 2 - .../server/service/ChatQueryService.java | 3 + .../service/impl/ChatQueryServiceImpl.java | 13 +++ .../main/resources/META-INF/spring.factories | 15 +++ 35 files changed, 463 insertions(+), 290 deletions(-) create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/ChatExecutor.java create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PluginExecutor.java create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/ChatParser.java create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/Text2PluginParser.java create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/Text2SqlParser.java create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginQueryManager.java delete mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/FunctionCallConfig.java delete mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/FunctionFiled.java delete mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/FunctionPromptGenerator.java delete mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/FunctionReq.java delete mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/FunctionResp.java delete mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/Parameters.java rename chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/{recall/PluginParser.java => recognize/PluginRecognizer.java} (68%) rename chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/{recall/embedding/EmbeddingRecallParser.java => recognize/embedding/EmbeddingRecallRecognizer.java} (82%) rename chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/{recall => recognize}/embedding/RecallRetrieval.java (74%) rename chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/{recall => recognize}/embedding/RecallRetrievalResp.java (69%) create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ChatExecuteContext.java create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ChatParseContext.java create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/RespBuildProcessor.java create mode 100644 chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java create mode 100644 headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/MapResp.java diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java index 2307043d6..e35613fc1 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/ChatParseReq.java @@ -1,6 +1,7 @@ package com.tencent.supersonic.chat.api.pojo.request; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import lombok.Data; @@ -12,5 +13,6 @@ public class ChatParseReq { private User user; private QueryFilters queryFilters; private boolean saveAnswer = true; + private SchemaMapInfo mapInfo = new SchemaMapInfo(); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/ChatExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/ChatExecutor.java new file mode 100644 index 000000000..98fa14a78 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/ChatExecutor.java @@ -0,0 +1,10 @@ +package com.tencent.supersonic.chat.server.executor; + +import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext; +import com.tencent.supersonic.headless.api.pojo.response.QueryResult; + +public interface ChatExecutor { + + QueryResult execute(ChatExecuteContext chatExecuteContext); + +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PluginExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PluginExecutor.java new file mode 100644 index 000000000..b584bc92d --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PluginExecutor.java @@ -0,0 +1,21 @@ +package com.tencent.supersonic.chat.server.executor; + +import com.tencent.supersonic.chat.server.plugin.PluginQueryManager; +import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery; +import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext; +import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.headless.api.pojo.response.QueryResult; + +public class PluginExecutor implements ChatExecutor { + + @Override + public QueryResult execute(ChatExecuteContext chatExecuteContext) { + SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo(); + if (!PluginQueryManager.isPluginQuery(parseInfo.getQueryMode())) { + return null; + } + PluginSemanticQuery query = PluginQueryManager.getPluginQuery(parseInfo.getQueryMode()); + return query.build(); + } + +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java new file mode 100644 index 000000000..cb82ad132 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java @@ -0,0 +1,38 @@ +package com.tencent.supersonic.chat.server.executor; + +import com.tencent.supersonic.chat.server.plugin.PluginQueryManager; +import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq; +import com.tencent.supersonic.headless.api.pojo.response.QueryResult; +import com.tencent.supersonic.headless.server.service.ChatQueryService; +import lombok.SneakyThrows; + +public class SqlExecutor implements ChatExecutor { + + @SneakyThrows + @Override + public QueryResult execute(ChatExecuteContext chatExecuteContext) { + SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo(); + if (PluginQueryManager.isPluginQuery(parseInfo.getQueryMode())) { + return null; + } + ExecuteQueryReq executeQueryReq = buildExecuteReq(chatExecuteContext); + ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class); + return chatQueryService.performExecution(executeQueryReq); + } + + private ExecuteQueryReq buildExecuteReq(ChatExecuteContext chatExecuteContext) { + SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo(); + return ExecuteQueryReq.builder() + .queryId(chatExecuteContext.getQueryId()) + .chatId(chatExecuteContext.getChatId()) + .queryText(chatExecuteContext.getQueryText()) + .parseInfo(parseInfo) + .saveAnswer(chatExecuteContext.isSaveAnswer()) + .user(chatExecuteContext.getUser()) + .build(); + } + +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/ChatParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/ChatParser.java new file mode 100644 index 000000000..8c50e3745 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/ChatParser.java @@ -0,0 +1,10 @@ +package com.tencent.supersonic.chat.server.parser; + +import com.tencent.supersonic.chat.server.pojo.ChatParseContext; +import com.tencent.supersonic.headless.api.pojo.response.ParseResp; + +public interface ChatParser { + + void parse(ChatParseContext chatParseContext, ParseResp parseResp); + +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/Text2PluginParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/Text2PluginParser.java new file mode 100644 index 000000000..805b7c6ef --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/Text2PluginParser.java @@ -0,0 +1,20 @@ +package com.tencent.supersonic.chat.server.parser; + +import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer; +import com.tencent.supersonic.chat.server.pojo.ChatParseContext; +import com.tencent.supersonic.chat.server.util.ComponentFactory; +import com.tencent.supersonic.headless.api.pojo.response.ParseResp; +import java.util.List; + +public class Text2PluginParser implements ChatParser { + + private final List pluginRecognizers = ComponentFactory.getPluginRecognizers(); + + @Override + public void parse(ChatParseContext chatParseContext, ParseResp parseResp) { + pluginRecognizers.forEach(pluginRecognizer -> { + pluginRecognizer.recognize(chatParseContext, parseResp); + }); + } + +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/Text2SqlParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/Text2SqlParser.java new file mode 100644 index 000000000..1faca2ddf --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/Text2SqlParser.java @@ -0,0 +1,22 @@ +package com.tencent.supersonic.chat.server.parser; + +import com.tencent.supersonic.chat.server.pojo.ChatParseContext; +import com.tencent.supersonic.chat.server.util.QueryReqConverter; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.headless.api.pojo.request.QueryReq; +import com.tencent.supersonic.headless.api.pojo.response.ParseResp; +import com.tencent.supersonic.headless.server.service.ChatQueryService; + +public class Text2SqlParser implements ChatParser { + + @Override + public void parse(ChatParseContext chatParseContext, ParseResp parseResp) { + QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); + ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class); + ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq); + if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) { + parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses()); + } + } + +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java index de6ed99aa..76f299ba7 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java @@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.server.plugin; import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Lists; import com.google.common.collect.Sets; -import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.AgentToolType; import com.tencent.supersonic.chat.server.agent.PluginTool; @@ -12,6 +11,7 @@ import com.tencent.supersonic.chat.server.plugin.build.WebBase; import com.tencent.supersonic.chat.server.plugin.event.PluginAddEvent; import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent; import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent; +import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.PluginService; import com.tencent.supersonic.common.config.EmbeddingConfig; @@ -26,7 +26,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; -import com.tencent.supersonic.headless.core.pojo.QueryContext; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.tuple.Pair; @@ -53,10 +52,10 @@ public class PluginManager { private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore(); - public static List getPluginAgentCanSupport(ChatParseReq chatParseReq) { + public static List getPluginAgentCanSupport(ChatParseContext chatParseContext) { PluginService pluginService = ContextUtils.getBean(PluginService.class); AgentService agentService = ContextUtils.getBean(AgentService.class); - Agent agent = agentService.getAgent(chatParseReq.getAgentId()); + Agent agent = agentService.getAgent(chatParseContext.getAgentId()); List plugins = pluginService.getPluginList(); if (Objects.isNull(agent)) { @@ -199,9 +198,9 @@ public class PluginManager { return String.valueOf(Integer.parseInt(id) / 1000); } - public static Pair> resolve(Plugin plugin, QueryContext queryContext) { - SchemaMapInfo schemaMapInfo = queryContext.getMapInfo(); - Set pluginMatchedModel = getPluginMatchedModel(plugin, queryContext); + public static Pair> resolve(Plugin plugin, ChatParseContext chatParseContext) { + SchemaMapInfo schemaMapInfo = chatParseContext.getMapInfo(); + Set pluginMatchedModel = getPluginMatchedModel(plugin, chatParseContext); if (CollectionUtils.isEmpty(pluginMatchedModel) && !plugin.isContainsAllModel()) { return Pair.of(false, Sets.newHashSet()); } @@ -267,8 +266,8 @@ public class PluginManager { .collect(Collectors.toList()); } - private static Set getPluginMatchedModel(Plugin plugin, QueryContext queryContext) { - Set matchedDataSets = queryContext.getMapInfo().getMatchedDataSetInfos(); + private static Set getPluginMatchedModel(Plugin plugin, ChatParseContext chatParseContext) { + Set matchedDataSets = chatParseContext.getMapInfo().getMatchedDataSetInfos(); if (plugin.isContainsAllModel()) { return Sets.newHashSet(plugin.getDefaultMode()); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginQueryManager.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginQueryManager.java new file mode 100644 index 000000000..4b649b647 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginQueryManager.java @@ -0,0 +1,24 @@ +package com.tencent.supersonic.chat.server.plugin; + +import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery; + +import java.util.HashMap; +import java.util.Map; + +public class PluginQueryManager { + + private static Map pluginQueries = new HashMap<>(); + + public static void register(String queryMode, PluginSemanticQuery pluginSemanticQuery) { + pluginQueries.put(queryMode, pluginSemanticQuery); + } + + public static boolean isPluginQuery(String queryMode) { + return pluginQueries.containsKey(queryMode); + } + + public static PluginSemanticQuery getPluginQuery(String queryMode) { + return pluginQueries.get(queryMode); + } + +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/PluginSemanticQuery.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/PluginSemanticQuery.java index f23647e7b..26679a2af 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/PluginSemanticQuery.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/PluginSemanticQuery.java @@ -1,14 +1,13 @@ package com.tencent.supersonic.chat.server.plugin.build; import com.google.common.collect.Lists; -import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.server.plugin.PluginParseResult; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; -import com.tencent.supersonic.headless.api.pojo.SemanticSchema; +import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; -import com.tencent.supersonic.headless.core.chat.query.BaseSemanticQuery; +import com.tencent.supersonic.headless.api.pojo.response.QueryResult; import lombok.extern.slf4j.Slf4j; import org.springframework.util.CollectionUtils; @@ -17,12 +16,11 @@ import java.util.List; import java.util.Map; @Slf4j -public abstract class PluginSemanticQuery extends BaseSemanticQuery { +public abstract class PluginSemanticQuery { - @Override - public void initS2Sql(SemanticSchema semanticSchema, User user) { + protected SemanticParseInfo parseInfo; - } + public abstract QueryResult build(); private Map getFilterMap(PluginParseResult pluginParseResult) { Map map = new HashMap<>(); @@ -91,4 +89,8 @@ public abstract class PluginSemanticQuery extends BaseSemanticQuery { return webBaseResult; } + public void setParseInfo(SemanticParseInfo parseInfo) { + this.parseInfo = parseInfo; + } + } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webpage/WebPageQuery.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webpage/WebPageQuery.java index 12cabed9d..9597ca547 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webpage/WebPageQuery.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webpage/WebPageQuery.java @@ -2,15 +2,18 @@ package com.tencent.supersonic.chat.server.plugin.build.webpage; import com.tencent.supersonic.chat.server.plugin.Plugin; import com.tencent.supersonic.chat.server.plugin.PluginParseResult; +import com.tencent.supersonic.chat.server.plugin.PluginQueryManager; import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery; import com.tencent.supersonic.chat.server.plugin.build.WebBase; +import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; -import com.tencent.supersonic.headless.core.chat.query.QueryManager; +import com.tencent.supersonic.headless.api.pojo.response.QueryResult; +import com.tencent.supersonic.headless.api.pojo.response.QueryState; import lombok.extern.slf4j.Slf4j; -import org.apache.calcite.sql.parser.SqlParseException; import org.springframework.stereotype.Component; +import java.util.Map; + @Slf4j @Component public class WebPageQuery extends PluginSemanticQuery { @@ -18,17 +21,7 @@ public class WebPageQuery extends PluginSemanticQuery { public static String QUERY_MODE = "WEB_PAGE"; public WebPageQuery() { - QueryManager.register(this); - } - - @Override - public String getQueryMode() { - return QUERY_MODE; - } - - @Override - public SemanticQueryReq buildSemanticQueryReq() throws SqlParseException { - return null; + PluginQueryManager.register(QUERY_MODE, this); } protected WebPageResp buildResponse(PluginParseResult pluginParseResult) { @@ -43,4 +36,17 @@ public class WebPageQuery extends PluginSemanticQuery { return webPageResponse; } + @Override + public QueryResult build() { + QueryResult queryResult = new QueryResult(); + queryResult.setQueryMode(QUERY_MODE); + Map properties = parseInfo.getProperties(); + PluginParseResult pluginParseResult = JsonUtil.toObject(JsonUtil.toString(properties.get(Constants.CONTEXT)), + PluginParseResult.class); + WebPageResp webPageResponse = buildResponse(pluginParseResult); + queryResult.setResponse(webPageResponse); + queryResult.setQueryState(QueryState.SUCCESS); + return queryResult; + } + } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webservice/WebServiceQuery.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webservice/WebServiceQuery.java index cb479bd5b..71d3fa12b 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webservice/WebServiceQuery.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webservice/WebServiceQuery.java @@ -3,15 +3,17 @@ package com.tencent.supersonic.chat.server.plugin.build.webservice; import com.alibaba.fastjson.JSON; import com.tencent.supersonic.chat.server.plugin.Plugin; import com.tencent.supersonic.chat.server.plugin.PluginParseResult; +import com.tencent.supersonic.chat.server.plugin.PluginQueryManager; import com.tencent.supersonic.chat.server.plugin.build.ParamOption; import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery; import com.tencent.supersonic.chat.server.plugin.build.WebBase; +import com.tencent.supersonic.common.pojo.Constants; +import com.tencent.supersonic.common.pojo.QueryColumn; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; -import com.tencent.supersonic.headless.core.chat.query.QueryManager; +import com.tencent.supersonic.headless.api.pojo.response.QueryResult; +import com.tencent.supersonic.headless.api.pojo.response.QueryState; import lombok.extern.slf4j.Slf4j; -import org.apache.calcite.sql.parser.SqlParseException; import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -36,17 +38,30 @@ public class WebServiceQuery extends PluginSemanticQuery { private RestTemplate restTemplate; public WebServiceQuery() { - QueryManager.register(this); + PluginQueryManager.register(QUERY_MODE, this); } @Override - public String getQueryMode() { - return QUERY_MODE; - } - - @Override - public SemanticQueryReq buildSemanticQueryReq() throws SqlParseException { - return null; + public QueryResult build() { + QueryResult queryResult = new QueryResult(); + queryResult.setQueryMode(QUERY_MODE); + Map properties = parseInfo.getProperties(); + PluginParseResult pluginParseResult = JsonUtil.toObject( + JsonUtil.toString(properties.get(Constants.CONTEXT)), PluginParseResult.class); + WebServiceResp webServiceResponse = buildResponse(pluginParseResult); + Object object = webServiceResponse.getResult(); + // in order to show webServiceQuery result int frontend conveniently, + // webServiceResponse result format is consistent with queryByStruct result. + log.info("webServiceResponse result:{}", JsonUtil.toString(object)); + try { + Map data = JsonUtil.toMap(JsonUtil.toString(object), String.class, Object.class); + queryResult.setQueryResults((List>) data.get("resultList")); + queryResult.setQueryColumns((List) data.get("columns")); + queryResult.setQueryState(QueryState.SUCCESS); + } catch (Exception e) { + log.info("webServiceResponse result has an exception:{}", e.getMessage()); + } + return queryResult; } protected WebServiceResp buildResponse(PluginParseResult pluginParseResult) { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/FunctionCallConfig.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/FunctionCallConfig.java deleted file mode 100644 index abab66945..000000000 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/FunctionCallConfig.java +++ /dev/null @@ -1,15 +0,0 @@ -package com.tencent.supersonic.chat.server.plugin.recall.function; - -import lombok.Data; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.context.annotation.Configuration; - -@Configuration -@Data -public class FunctionCallConfig { - @Value("${functionCall.url:}") - private String url; - - @Value("${funtionCall.plugin.select.path:/plugin_selection}") - private String pluginSelectPath; -} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/FunctionFiled.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/FunctionFiled.java deleted file mode 100644 index 6bbe3b568..000000000 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/FunctionFiled.java +++ /dev/null @@ -1,13 +0,0 @@ -package com.tencent.supersonic.chat.server.plugin.recall.function; - - -import lombok.Data; - -@Data -public class FunctionFiled { - - private String type; - - private String description; - -} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/FunctionPromptGenerator.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/FunctionPromptGenerator.java deleted file mode 100644 index 576e16b0f..000000000 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/FunctionPromptGenerator.java +++ /dev/null @@ -1,75 +0,0 @@ -package com.tencent.supersonic.chat.server.plugin.recall.function; - - -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.tencent.supersonic.chat.server.plugin.PluginParseConfig; -import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.headless.core.chat.parser.llm.InputFormat; -import dev.langchain4j.model.chat.ChatLanguageModel; -import lombok.extern.slf4j.Slf4j; -import org.springframework.stereotype.Component; - -import java.util.List; -import java.util.stream.Collectors; - -@Component -@Slf4j -public class FunctionPromptGenerator { - - public String generateFunctionCallPrompt(String queryText, List toolConfigList) { - List toolExplainList = toolConfigList.stream() - .map(this::constructPluginPrompt) - .collect(Collectors.toList()); - String functionList = String.join(InputFormat.SEPERATOR, toolExplainList); - return constructTaskPrompt(queryText, functionList); - } - - public String constructPluginPrompt(PluginParseConfig parseConfig) { - String toolName = parseConfig.getName(); - String toolDescription = parseConfig.getDescription(); - List toolExamples = parseConfig.getExamples(); - - StringBuilder prompt = new StringBuilder(); - prompt.append("【工具名称】\n").append(toolName).append("\n"); - prompt.append("【工具描述】\n").append(toolDescription).append("\n"); - prompt.append("【工具适用问题示例】\n"); - for (String example : toolExamples) { - prompt.append(example).append("\n"); - } - return prompt.toString(); - } - - public String constructTaskPrompt(String queryText, String functionList) { - String instruction = String.format("问题为:%s\n请根据问题和工具的描述,选择对应的工具,完成任务。" - + "请注意,只能选择1个工具。请一步一步地分析选择工具的原因(每个工具的【工具适用问题示例】是选择的重要参考依据)," - + "并给出最终选择,输出格式为json,key为’分析过程‘, ’选择工具‘", queryText); - - return String.format("工具选择如下:\n\n%s\n\n【任务说明】\n%s", functionList, instruction); - } - - public FunctionResp requestFunction(FunctionReq functionReq) { - - FunctionPromptGenerator promptGenerator = ContextUtils.getBean(FunctionPromptGenerator.class); - - ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class); - String functionCallPrompt = promptGenerator.generateFunctionCallPrompt(functionReq.getQueryText(), - functionReq.getPluginConfigs()); - String response = chatLanguageModel.generate(functionCallPrompt); - return functionCallParse(response); - } - - public static FunctionResp functionCallParse(String llmOutput) { - try { - ObjectMapper objectMapper = new ObjectMapper(); - JsonNode jsonNode = objectMapper.readTree(llmOutput); - String selectedTool = jsonNode.get("选择工具").asText(); - FunctionResp resp = new FunctionResp(); - resp.setToolSelection(selectedTool); - return resp; - } catch (Exception e) { - log.error("", e); - } - return null; - } -} \ No newline at end of file diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/FunctionReq.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/FunctionReq.java deleted file mode 100644 index ca6f05973..000000000 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/FunctionReq.java +++ /dev/null @@ -1,17 +0,0 @@ -package com.tencent.supersonic.chat.server.plugin.recall.function; - -import com.tencent.supersonic.chat.server.plugin.PluginParseConfig; -import lombok.Builder; -import lombok.Data; - -import java.util.List; - -@Data -@Builder -public class FunctionReq { - - private String queryText; - - private List pluginConfigs; - -} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/FunctionResp.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/FunctionResp.java deleted file mode 100644 index 7fe0a5b59..000000000 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/FunctionResp.java +++ /dev/null @@ -1,10 +0,0 @@ -package com.tencent.supersonic.chat.server.plugin.recall.function; - -import lombok.Data; - -@Data -public class FunctionResp { - - private String toolSelection; - -} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/Parameters.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/Parameters.java deleted file mode 100644 index c78c668cb..000000000 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/function/Parameters.java +++ /dev/null @@ -1,18 +0,0 @@ -package com.tencent.supersonic.chat.server.plugin.recall.function; - -import lombok.Data; - -import java.util.List; -import java.util.Map; - -@Data -public class Parameters { - - //default: object - private String type = "object"; - - private Map properties; - - private List required; - -} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/PluginParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java similarity index 68% rename from chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/PluginParser.java rename to chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java index aa0e50ed2..95f5446e8 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/PluginParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java @@ -1,13 +1,12 @@ -package com.tencent.supersonic.chat.server.plugin.recall; +package com.tencent.supersonic.chat.server.plugin.recognize; import com.google.common.collect.Lists; import com.google.common.collect.Sets; -import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.server.plugin.Plugin; import com.tencent.supersonic.chat.server.plugin.PluginManager; import com.tencent.supersonic.chat.server.plugin.PluginParseResult; import com.tencent.supersonic.chat.server.plugin.PluginRecallResult; -import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery; +import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; @@ -15,7 +14,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; -import com.tencent.supersonic.headless.core.pojo.QueryContext; +import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import org.springframework.util.CollectionUtils; import java.util.HashMap; @@ -26,64 +25,58 @@ import java.util.Set; /** * PluginParser defines the basic process and common methods for recalling plugins. */ -public abstract class PluginParser { +public abstract class PluginRecognizer { - public void parse(ChatParseReq chatParseReq) { - if (!checkPreCondition(chatParseReq)) { + public void recognize(ChatParseContext chatParseContext, ParseResp parseResp) { + if (!checkPreCondition(chatParseContext)) { return; } - PluginRecallResult pluginRecallResult = recallPlugin(chatParseReq); + PluginRecallResult pluginRecallResult = recallPlugin(chatParseContext); if (pluginRecallResult == null) { return; } - buildQuery(chatParseReq, pluginRecallResult); + buildQuery(chatParseContext, parseResp, pluginRecallResult); } - public abstract boolean checkPreCondition(ChatParseReq chatParseReq); + public abstract boolean checkPreCondition(ChatParseContext chatParseContext); - public abstract PluginRecallResult recallPlugin(ChatParseReq chatParseReq); + public abstract PluginRecallResult recallPlugin(ChatParseContext chatParseContext); - public void buildQuery(ChatParseReq chatParseReq, PluginRecallResult pluginRecallResult) { + public void buildQuery(ChatParseContext chatParseContext, ParseResp parseResp, + PluginRecallResult pluginRecallResult) { Plugin plugin = pluginRecallResult.getPlugin(); Set dataSetIds = pluginRecallResult.getDataSetIds(); if (plugin.isContainsAllModel()) { dataSetIds = Sets.newHashSet(-1L); } for (Long dataSetId : dataSetIds) { - //todo - PluginSemanticQuery pluginQuery = null; SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin, - null, pluginRecallResult.getDistance()); - semanticParseInfo.setQueryMode(pluginQuery.getQueryMode()); + chatParseContext, pluginRecallResult.getDistance()); + semanticParseInfo.setQueryMode(plugin.getType()); semanticParseInfo.setScore(pluginRecallResult.getScore()); - pluginQuery.setParseInfo(semanticParseInfo); - //chatParseReq.getCandidateQueries().add(pluginQuery); + parseResp.getSelectedParses().add(semanticParseInfo); } } - protected List getPluginList(ChatParseReq chatParseReq) { - return PluginManager.getPluginAgentCanSupport(chatParseReq); + protected List getPluginList(ChatParseContext chatParseContext) { + return PluginManager.getPluginAgentCanSupport(chatParseContext); } protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, Plugin plugin, - QueryContext queryContext, double distance) { - List schemaElementMatches = queryContext.getMapInfo().getMatchedElements(dataSetId); - QueryFilters queryFilters = queryContext.getQueryFilters(); - if (dataSetId == null && !CollectionUtils.isEmpty(plugin.getDataSetList())) { - dataSetId = plugin.getDataSetList().get(0); - } + ChatParseContext chatParseContext, double distance) { + List schemaElementMatches = chatParseContext.getMapInfo().getMatchedElements(dataSetId); + QueryFilters queryFilters = chatParseContext.getQueryFilters(); if (schemaElementMatches == null) { schemaElementMatches = Lists.newArrayList(); } SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); semanticParseInfo.setElementMatches(schemaElementMatches); - semanticParseInfo.setDataSet(queryContext.getSemanticSchema().getDataSet(dataSetId)); Map properties = new HashMap<>(); PluginParseResult pluginParseResult = new PluginParseResult(); pluginParseResult.setPlugin(plugin); pluginParseResult.setQueryFilters(queryFilters); pluginParseResult.setDistance(distance); - pluginParseResult.setQueryText(queryContext.getQueryText()); + pluginParseResult.setQueryText(chatParseContext.getQueryText()); properties.put(Constants.CONTEXT, pluginParseResult); properties.put("type", "plugin"); properties.put("name", plugin.getName()); @@ -111,4 +104,5 @@ public abstract class PluginParser { semanticParseInfo.getDimensionFilters().add(queryFilter); }); } + } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/embedding/EmbeddingRecallParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java similarity index 82% rename from chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/embedding/EmbeddingRecallParser.java rename to chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java index ab3b045eb..286862cd9 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/embedding/EmbeddingRecallParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java @@ -1,12 +1,12 @@ -package com.tencent.supersonic.chat.server.plugin.recall.embedding; +package com.tencent.supersonic.chat.server.plugin.recognize.embedding; import com.google.common.collect.Lists; -import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.server.plugin.ParseMode; import com.tencent.supersonic.chat.server.plugin.Plugin; import com.tencent.supersonic.chat.server.plugin.PluginManager; import com.tencent.supersonic.chat.server.plugin.PluginRecallResult; -import com.tencent.supersonic.chat.server.plugin.recall.PluginParser; +import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer; +import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.embedding.Retrieval; @@ -28,32 +28,31 @@ import java.util.stream.Collectors; * EmbeddingRecallParser is an implementation of a recall plugin based on Embedding */ @Slf4j -public class EmbeddingRecallParser extends PluginParser { +public class EmbeddingRecallRecognizer extends PluginRecognizer { - public boolean checkPreCondition(ChatParseReq chatParseReq) { + public boolean checkPreCondition(ChatParseContext chatParseContext) { EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class); if (StringUtils.isBlank(embeddingConfig.getUrl()) && ComponentFactory.getLLMProxy() instanceof PythonLLMProxy) { return false; } - List plugins = getPluginList(chatParseReq); + List plugins = getPluginList(chatParseContext); return !CollectionUtils.isEmpty(plugins); } - public PluginRecallResult recallPlugin(ChatParseReq chatParseReq) { - String text = chatParseReq.getQueryText(); + public PluginRecallResult recallPlugin(ChatParseContext chatParseContext) { + String text = chatParseContext.getQueryText(); List embeddingRetrievals = embeddingRecall(text); if (CollectionUtils.isEmpty(embeddingRetrievals)) { return null; } - List plugins = getPluginList(chatParseReq); + List plugins = getPluginList(chatParseContext); Map pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p)); for (Retrieval embeddingRetrieval : embeddingRetrievals) { Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId())); if (plugin == null) { continue; } - //todo - Pair> pair = PluginManager.resolve(plugin, null); + Pair> pair = PluginManager.resolve(plugin, chatParseContext); log.info("embedding plugin resolve: {}", pair); if (pair.getLeft()) { Set dataSetList = pair.getRight(); @@ -62,7 +61,7 @@ public class EmbeddingRecallParser extends PluginParser { } plugin.setParseMode(ParseMode.EMBEDDING_RECALL); double distance = embeddingRetrieval.getDistance(); - double score = chatParseReq.getQueryText().length() * (1 - distance); + double score = chatParseContext.getQueryText().length() * (1 - distance); return PluginRecallResult.builder() .plugin(plugin).dataSetIds(dataSetList).score(score).distance(distance).build(); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/embedding/RecallRetrieval.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/RecallRetrieval.java similarity index 74% rename from chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/embedding/RecallRetrieval.java rename to chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/RecallRetrieval.java index c347eeabd..02a808051 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/embedding/RecallRetrieval.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/RecallRetrieval.java @@ -1,4 +1,4 @@ -package com.tencent.supersonic.chat.server.plugin.recall.embedding; +package com.tencent.supersonic.chat.server.plugin.recognize.embedding; import lombok.Data; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/embedding/RecallRetrievalResp.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/RecallRetrievalResp.java similarity index 69% rename from chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/embedding/RecallRetrievalResp.java rename to chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/RecallRetrievalResp.java index 7c0de6ac5..ae714b337 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recall/embedding/RecallRetrievalResp.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/RecallRetrievalResp.java @@ -1,4 +1,4 @@ -package com.tencent.supersonic.chat.server.plugin.recall.embedding; +package com.tencent.supersonic.chat.server.plugin.recognize.embedding; import lombok.Data; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ChatExecuteContext.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ChatExecuteContext.java new file mode 100644 index 000000000..cc8b97f9b --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ChatExecuteContext.java @@ -0,0 +1,16 @@ +package com.tencent.supersonic.chat.server.pojo; + +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; +import lombok.Data; + +@Data +public class ChatExecuteContext { + private User user; + private Long queryId; + private Integer chatId; + private int parseId; + private String queryText; + private boolean saveAnswer; + private SemanticParseInfo parseInfo; +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ChatParseContext.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ChatParseContext.java new file mode 100644 index 000000000..74486128c --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/pojo/ChatParseContext.java @@ -0,0 +1,17 @@ +package com.tencent.supersonic.chat.server.pojo; + +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; +import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; +import lombok.Data; + +@Data +public class ChatParseContext { + private String queryText; + private Integer chatId; + private Integer agentId; + private User user; + private QueryFilters queryFilters; + private boolean saveAnswer = true; + private SchemaMapInfo mapInfo = new SchemaMapInfo(); +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseResultProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseResultProcessor.java index ead11facf..d5e289dac 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseResultProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseResultProcessor.java @@ -1,10 +1,10 @@ package com.tencent.supersonic.chat.server.processor.parse; -import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; +import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; public interface ParseResultProcessor { - void process(ParseResp parseResp, ChatParseReq chatParseReq); + void process(ChatParseContext chatParseContext, ParseResp parseResp); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java index 8aab0b8fd..46c373a0d 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java @@ -7,12 +7,10 @@ import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq; import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO; import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository; +import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.QueryResp; -import com.tencent.supersonic.headless.core.pojo.ChatContext; -import com.tencent.supersonic.headless.core.pojo.QueryContext; -import com.tencent.supersonic.headless.server.processor.ResultProcessor; import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import org.springframework.util.CollectionUtils; @@ -26,18 +24,17 @@ import java.util.stream.Collectors; * MetricRecommendProcessor fills recommended query based on embedding similarity. */ @Slf4j -public class QueryRecommendProcessor implements ResultProcessor { +public class QueryRecommendProcessor implements ParseResultProcessor { @Override - public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) { - CompletableFuture.runAsync(() -> doProcess(parseResp, queryContext)); + public void process(ChatParseContext chatParseContext, ParseResp parseResp) { + CompletableFuture.runAsync(() -> doProcess(parseResp, chatParseContext)); } @SneakyThrows - private void doProcess(ParseResp parseResp, QueryContext queryContext) { + private void doProcess(ParseResp parseResp, ChatParseContext chatParseContext) { Long queryId = parseResp.getQueryId(); - //TODO - List solvedQueries = getSimilarQueries(queryContext.getQueryText(), + List solvedQueries = getSimilarQueries(chatParseContext.getQueryText(), null); ChatQueryDO chatQueryDO = getChatQuery(queryId); chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries)); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/RespBuildProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/RespBuildProcessor.java new file mode 100644 index 000000000..5d974995b --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/RespBuildProcessor.java @@ -0,0 +1,28 @@ +package com.tencent.supersonic.chat.server.processor.parse; + +import com.tencent.supersonic.chat.server.pojo.ChatParseContext; +import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.headless.api.pojo.response.ParseResp; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections4.CollectionUtils; + +import java.util.List; + +/** + * RespBuildProcessor fill response object with parsing results. + **/ +@Slf4j +public class RespBuildProcessor implements ParseResultProcessor { + + @Override + public void process(ChatParseContext chatParseContext, ParseResp parseResp) { + parseResp.setChatId(chatParseContext.getChatId()); + parseResp.setQueryText(chatParseContext.getQueryText()); + List parseInfos = parseResp.getSelectedParses(); + if (CollectionUtils.isNotEmpty(parseInfos)) { + parseResp.setState(ParseResp.ParseState.COMPLETED); + } else { + parseResp.setState(ParseResp.ParseState.FAILED); + } + } +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java index a8bfe2495..b818acf50 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatServiceImpl.java @@ -7,22 +7,27 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq; import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq; import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp; -import com.tencent.supersonic.chat.server.agent.Agent; +import com.tencent.supersonic.chat.server.executor.ChatExecutor; +import com.tencent.supersonic.chat.server.parser.ChatParser; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO; import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO; import com.tencent.supersonic.chat.server.persistence.dataobject.QueryDO; import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository; import com.tencent.supersonic.chat.server.persistence.repository.ChatRepository; -import com.tencent.supersonic.chat.server.service.AgentService; +import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext; +import com.tencent.supersonic.chat.server.pojo.ChatParseContext; +import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor; import com.tencent.supersonic.chat.server.service.ChatService; +import com.tencent.supersonic.chat.server.util.ComponentFactory; +import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; -import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq; import com.tencent.supersonic.headless.api.pojo.request.QueryReq; +import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.QueryResp; import com.tencent.supersonic.headless.api.pojo.response.QueryResult; @@ -33,7 +38,6 @@ import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; - import java.text.SimpleDateFormat; import java.util.ArrayList; import java.util.Comparator; @@ -55,32 +59,65 @@ public class ChatServiceImpl implements ChatService { @Autowired private ChatQueryService chatQueryService; @Autowired - private AgentService agentService; - @Autowired private SearchService searchService; + private List chatParsers = ComponentFactory.getChatParsers(); + private List chatExecutors = ComponentFactory.getChatExecutors(); + private List parseResultProcessors = ComponentFactory.getParseProcessors(); @Override public List search(ChatParseReq chatParseReq) { - QueryReq queryReq = buildSqlQueryReq(chatParseReq); + ChatParseContext chatParseContext = buildParseContext(chatParseReq); + QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); return searchService.search(queryReq); } @Override public ParseResp performParsing(ChatParseReq chatParseReq) { - QueryReq queryReq = buildSqlQueryReq(chatParseReq); - ParseResp parseResp = chatQueryService.performParsing(queryReq); + ParseResp parseResp = new ParseResp(); + ChatParseContext chatParseContext = buildParseContext(chatParseReq); + for (ChatParser chatParser : chatParsers) { + chatParser.parse(chatParseContext, parseResp); + } + for (ParseResultProcessor processor : parseResultProcessors) { + processor.process(chatParseContext, parseResp); + } batchAddParse(chatParseReq, parseResp); return parseResp; } @Override - public QueryResult performExecution(ChatExecuteReq chatExecuteReq) throws Exception { - ExecuteQueryReq executeQueryReq = buildExecuteReq(chatExecuteReq); - QueryResult queryResult = chatQueryService.performExecution(executeQueryReq); + public QueryResult performExecution(ChatExecuteReq chatExecuteReq) { + QueryResult queryResult = new QueryResult(); + ChatExecuteContext chatExecuteContext = buildExecuteContext(chatExecuteReq); + for (ChatExecutor chatExecutor : chatExecutors) { + queryResult = chatExecutor.execute(chatExecuteContext); + if (queryResult != null) { + break; + } + } saveQueryResult(chatExecuteReq, queryResult); return queryResult; } + private ChatParseContext buildParseContext(ChatParseReq chatParseReq) { + ChatParseContext chatParseContext = new ChatParseContext(); + BeanMapper.mapper(chatParseReq, chatParseContext); + QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext); + MapResp mapResp = chatQueryService.performMapping(queryReq); + chatParseContext.setMapInfo(mapResp.getMapInfo()); + return chatParseContext; + } + + private ChatExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) { + ChatExecuteContext chatExecuteContext = new ChatExecuteContext(); + BeanMapper.mapper(chatExecuteReq, chatExecuteContext); + ChatParseDO chatParseDO = getParseInfo(chatExecuteReq.getQueryId(), chatExecuteReq.getParseId()); + SemanticParseInfo semanticParseInfo = JSONObject.parseObject(chatParseDO.getParseInfo(), + SemanticParseInfo.class); + chatExecuteContext.setParseInfo(semanticParseInfo); + return chatExecuteContext; + } + @Override public Object queryData(QueryDataReq queryData, User user) throws Exception { return chatQueryService.executeDirectQuery(queryData, user); @@ -96,36 +133,6 @@ public class ChatServiceImpl implements ChatService { return chatQueryService.queryDimensionValue(dimensionValueReq, user); } - private QueryReq buildSqlQueryReq(ChatParseReq chatParseReq) { - QueryReq queryReq = new QueryReq(); - BeanMapper.mapper(chatParseReq, queryReq); - if (chatParseReq.getAgentId() == null) { - return queryReq; - } - Agent agent = agentService.getAgent(chatParseReq.getAgentId()); - if (agent == null) { - return queryReq; - } - if (agent.containsLLMParserTool()) { - queryReq.setEnableLLM(true); - } - queryReq.setDataSetIds(agent.getDataSetIds()); - return queryReq; - } - - private ExecuteQueryReq buildExecuteReq(ChatExecuteReq chatExecuteReq) { - ChatParseDO chatParseDO = getParseInfo(chatExecuteReq.getQueryId(), chatExecuteReq.getParseId()); - SemanticParseInfo parseInfo = JSONObject.parseObject(chatParseDO.getParseInfo(), SemanticParseInfo.class); - return ExecuteQueryReq.builder() - .queryId(chatExecuteReq.getQueryId()) - .chatId(chatExecuteReq.getChatId()) - .queryText(chatExecuteReq.getQueryText()) - .parseInfo(parseInfo) - .saveAnswer(chatExecuteReq.isSaveAnswer()) - .user(chatExecuteReq.getUser()) - .build(); - } - @Override public Boolean addChat(User user, String chatName, Integer agentId) { ChatDO chatDO = new ChatDO(); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ComponentFactory.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ComponentFactory.java index 1242c6270..33a5af0ad 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ComponentFactory.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ComponentFactory.java @@ -1,7 +1,10 @@ package com.tencent.supersonic.chat.server.util; +import com.tencent.supersonic.chat.server.executor.ChatExecutor; +import com.tencent.supersonic.chat.server.parser.ChatParser; +import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer; import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor; -import com.tencent.supersonic.headless.server.processor.ResultProcessor; +import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.springframework.core.io.support.SpringFactoriesLoader; @@ -11,11 +14,14 @@ import java.util.List; @Slf4j public class ComponentFactory { - private static List parseProcessors = new ArrayList<>(); + private static List parseProcessors = new ArrayList<>(); private static List executeProcessors = new ArrayList<>(); + private static List chatParsers = new ArrayList<>(); + private static List chatExecutors = new ArrayList<>(); + private static List pluginRecognizers = new ArrayList<>(); - public static List getParseProcessors() { - return CollectionUtils.isEmpty(parseProcessors) ? init(ResultProcessor.class, + public static List getParseProcessors() { + return CollectionUtils.isEmpty(parseProcessors) ? init(ParseResultProcessor.class, parseProcessors) : parseProcessors; } @@ -24,6 +30,21 @@ public class ComponentFactory { ? init(ExecuteResultProcessor.class, executeProcessors) : executeProcessors; } + public static List getChatParsers() { + return CollectionUtils.isEmpty(chatParsers) + ? init(ChatParser.class, chatParsers) : chatParsers; + } + + public static List getChatExecutors() { + return CollectionUtils.isEmpty(chatExecutors) + ? init(ChatExecutor.class, chatExecutors) : chatExecutors; + } + + public static List getPluginRecognizers() { + return CollectionUtils.isEmpty(pluginRecognizers) + ? init(PluginRecognizer.class, pluginRecognizers) : pluginRecognizers; + } + private static List init(Class factoryType, List list) { list.addAll(SpringFactoriesLoader.loadFactories(factoryType, Thread.currentThread().getContextClassLoader())); @@ -34,4 +55,5 @@ public class ComponentFactory { return SpringFactoriesLoader.loadFactories(factoryType, Thread.currentThread().getContextClassLoader()).get(0); } + } \ No newline at end of file diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java new file mode 100644 index 000000000..fd8d8b6a6 --- /dev/null +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/QueryReqConverter.java @@ -0,0 +1,30 @@ +package com.tencent.supersonic.chat.server.util; + +import com.tencent.supersonic.chat.server.agent.Agent; +import com.tencent.supersonic.chat.server.pojo.ChatParseContext; +import com.tencent.supersonic.chat.server.service.AgentService; +import com.tencent.supersonic.common.util.BeanMapper; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.headless.api.pojo.request.QueryReq; + +public class QueryReqConverter { + + public static QueryReq buildText2SqlQueryReq(ChatParseContext chatParseContext) { + QueryReq queryReq = new QueryReq(); + BeanMapper.mapper(chatParseContext, queryReq); + if (chatParseContext.getAgentId() == null) { + return queryReq; + } + AgentService agentService = ContextUtils.getBean(AgentService.class); + Agent agent = agentService.getAgent(chatParseContext.getAgentId()); + if (agent == null) { + return queryReq; + } + if (agent.containsLLMParserTool()) { + queryReq.setEnableLLM(true); + } + queryReq.setDataSetIds(agent.getDataSetIds()); + return queryReq; + } + +} diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/MapResp.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/MapResp.java new file mode 100644 index 000000000..8603ffda8 --- /dev/null +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/MapResp.java @@ -0,0 +1,13 @@ +package com.tencent.supersonic.headless.api.pojo.response; + +import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; +import lombok.Data; + +@Data +public class MapResp { + + private String queryText; + + private SchemaMapInfo mapInfo = new SchemaMapInfo(); + +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/ChatQueryApiController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/ChatQueryApiController.java index 7f8163d08..b9aa981bf 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/ChatQueryApiController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/api/ChatQueryApiController.java @@ -11,7 +11,6 @@ import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RestController; - import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -22,7 +21,6 @@ public class ChatQueryApiController { @Autowired private ChatQueryService chatQueryService; - @Autowired private SearchService searchService; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ChatQueryService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ChatQueryService.java index fd0f8db83..88f62435b 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ChatQueryService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/ChatQueryService.java @@ -7,6 +7,7 @@ import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq; import com.tencent.supersonic.headless.api.pojo.request.QueryReq; +import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.QueryResult; @@ -15,6 +16,8 @@ import com.tencent.supersonic.headless.api.pojo.response.QueryResult; */ public interface ChatQueryService { + MapResp performMapping(QueryReq queryReq); + ParseResp performParsing(QueryReq queryReq); QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java index 62e022649..8b98cafae 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ChatQueryServiceImpl.java @@ -32,6 +32,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryReq; import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.response.ExplainResp; +import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.QueryResult; import com.tencent.supersonic.headless.api.pojo.response.QueryState; @@ -106,6 +107,18 @@ public class ChatQueryServiceImpl implements ChatQueryService { private List semanticCorrectors = ComponentFactory.getSemanticCorrectors(); private List resultProcessors = ComponentFactory.getResultProcessors(); + @Override + public MapResp performMapping(QueryReq queryReq) { + MapResp mapResp = new MapResp(); + QueryContext queryCtx = buildQueryContext(queryReq); + schemaMappers.forEach(mapper -> { + mapper.map(queryCtx); + }); + SchemaMapInfo mapInfo = queryCtx.getMapInfo(); + mapResp.setMapInfo(mapInfo); + return mapResp; + } + @Override public ParseResp performParsing(QueryReq queryReq) { ParseResp parseResult = new ParseResp(); diff --git a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories index b37edb18b..601d607a9 100644 --- a/launchers/standalone/src/main/resources/META-INF/spring.factories +++ b/launchers/standalone/src/main/resources/META-INF/spring.factories @@ -9,6 +9,14 @@ com.tencent.supersonic.headless.core.chat.parser.SemanticParser=\ com.tencent.supersonic.headless.core.chat.parser.llm.LLMSqlParser, \ com.tencent.supersonic.headless.core.chat.parser.QueryTypeParser +com.tencent.supersonic.chat.server.parser.ChatParser=\ + com.tencent.supersonic.chat.server.parser.Text2PluginParser, \ + com.tencent.supersonic.chat.server.parser.Text2SqlParser + +com.tencent.supersonic.chat.server.executor.ChatExecutor=\ + com.tencent.supersonic.chat.server.executor.PluginExecutor, \ + com.tencent.supersonic.chat.server.executor.SqlExecutor + com.tencent.supersonic.headless.core.chat.corrector.SemanticCorrector=\ com.tencent.supersonic.headless.core.chat.corrector.SchemaCorrector, \ com.tencent.supersonic.headless.core.chat.corrector.TimeCorrector, \ @@ -48,6 +56,13 @@ com.tencent.supersonic.auth.authentication.interceptor.AuthenticationInterceptor com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\ com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor +com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer=\ + com.tencent.supersonic.chat.server.plugin.recognize.embedding.EmbeddingRecallRecognizer + +com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor=\ + com.tencent.supersonic.chat.server.processor.parse.RespBuildProcessor,\ + com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor + com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\ com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\ com.tencent.supersonic.chat.server.processor.execute.DimensionRecommendProcessor,\