mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(Chat) Integrate chat with plugin recognizer and parse result processor (#820)
Co-authored-by: jolunoluo
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import lombok.Data;
|
||||
|
||||
@@ -12,5 +13,6 @@ public class ChatParseReq {
|
||||
private User user;
|
||||
private QueryFilters queryFilters;
|
||||
private boolean saveAnswer = true;
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
|
||||
public interface ChatExecutor {
|
||||
|
||||
QueryResult execute(ChatExecuteContext chatExecuteContext);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
|
||||
public class PluginExecutor implements ChatExecutor {
|
||||
|
||||
@Override
|
||||
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
|
||||
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo();
|
||||
if (!PluginQueryManager.isPluginQuery(parseInfo.getQueryMode())) {
|
||||
return null;
|
||||
}
|
||||
PluginSemanticQuery query = PluginQueryManager.getPluginQuery(parseInfo.getQueryMode());
|
||||
return query.build();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.server.service.ChatQueryService;
|
||||
import lombok.SneakyThrows;
|
||||
|
||||
public class SqlExecutor implements ChatExecutor {
|
||||
|
||||
@SneakyThrows
|
||||
@Override
|
||||
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
|
||||
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo();
|
||||
if (PluginQueryManager.isPluginQuery(parseInfo.getQueryMode())) {
|
||||
return null;
|
||||
}
|
||||
ExecuteQueryReq executeQueryReq = buildExecuteReq(chatExecuteContext);
|
||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||
return chatQueryService.performExecution(executeQueryReq);
|
||||
}
|
||||
|
||||
private ExecuteQueryReq buildExecuteReq(ChatExecuteContext chatExecuteContext) {
|
||||
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo();
|
||||
return ExecuteQueryReq.builder()
|
||||
.queryId(chatExecuteContext.getQueryId())
|
||||
.chatId(chatExecuteContext.getChatId())
|
||||
.queryText(chatExecuteContext.getQueryText())
|
||||
.parseInfo(parseInfo)
|
||||
.saveAnswer(chatExecuteContext.isSaveAnswer())
|
||||
.user(chatExecuteContext.getUser())
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
|
||||
public interface ChatParser {
|
||||
|
||||
void parse(ChatParseContext chatParseContext, ParseResp parseResp);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import java.util.List;
|
||||
|
||||
public class Text2PluginParser implements ChatParser {
|
||||
|
||||
private final List<PluginRecognizer> pluginRecognizers = ComponentFactory.getPluginRecognizers();
|
||||
|
||||
@Override
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
pluginRecognizers.forEach(pluginRecognizer -> {
|
||||
pluginRecognizer.recognize(chatParseContext, parseResp);
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.server.service.ChatQueryService;
|
||||
|
||||
public class Text2SqlParser implements ChatParser {
|
||||
|
||||
@Override
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq);
|
||||
if (!ParseResp.ParseState.FAILED.equals(text2SqlParseResp.getState())) {
|
||||
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.server.plugin;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.agent.PluginTool;
|
||||
@@ -12,6 +11,7 @@ import com.tencent.supersonic.chat.server.plugin.build.WebBase;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginAddEvent;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.PluginService;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
@@ -26,7 +26,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
@@ -53,10 +52,10 @@ public class PluginManager {
|
||||
|
||||
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||
|
||||
public static List<Plugin> getPluginAgentCanSupport(ChatParseReq chatParseReq) {
|
||||
public static List<Plugin> getPluginAgentCanSupport(ChatParseContext chatParseContext) {
|
||||
PluginService pluginService = ContextUtils.getBean(PluginService.class);
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
|
||||
Agent agent = agentService.getAgent(chatParseContext.getAgentId());
|
||||
|
||||
List<Plugin> plugins = pluginService.getPluginList();
|
||||
if (Objects.isNull(agent)) {
|
||||
@@ -199,9 +198,9 @@ public class PluginManager {
|
||||
return String.valueOf(Integer.parseInt(id) / 1000);
|
||||
}
|
||||
|
||||
public static Pair<Boolean, Set<Long>> resolve(Plugin plugin, QueryContext queryContext) {
|
||||
SchemaMapInfo schemaMapInfo = queryContext.getMapInfo();
|
||||
Set<Long> pluginMatchedModel = getPluginMatchedModel(plugin, queryContext);
|
||||
public static Pair<Boolean, Set<Long>> resolve(Plugin plugin, ChatParseContext chatParseContext) {
|
||||
SchemaMapInfo schemaMapInfo = chatParseContext.getMapInfo();
|
||||
Set<Long> pluginMatchedModel = getPluginMatchedModel(plugin, chatParseContext);
|
||||
if (CollectionUtils.isEmpty(pluginMatchedModel) && !plugin.isContainsAllModel()) {
|
||||
return Pair.of(false, Sets.newHashSet());
|
||||
}
|
||||
@@ -267,8 +266,8 @@ public class PluginManager {
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static Set<Long> getPluginMatchedModel(Plugin plugin, QueryContext queryContext) {
|
||||
Set<Long> matchedDataSets = queryContext.getMapInfo().getMatchedDataSetInfos();
|
||||
private static Set<Long> getPluginMatchedModel(Plugin plugin, ChatParseContext chatParseContext) {
|
||||
Set<Long> matchedDataSets = chatParseContext.getMapInfo().getMatchedDataSetInfos();
|
||||
if (plugin.isContainsAllModel()) {
|
||||
return Sets.newHashSet(plugin.getDefaultMode());
|
||||
}
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
package com.tencent.supersonic.chat.server.plugin;
|
||||
|
||||
import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public class PluginQueryManager {
|
||||
|
||||
private static Map<String, PluginSemanticQuery> pluginQueries = new HashMap<>();
|
||||
|
||||
public static void register(String queryMode, PluginSemanticQuery pluginSemanticQuery) {
|
||||
pluginQueries.put(queryMode, pluginSemanticQuery);
|
||||
}
|
||||
|
||||
public static boolean isPluginQuery(String queryMode) {
|
||||
return pluginQueries.containsKey(queryMode);
|
||||
}
|
||||
|
||||
public static PluginSemanticQuery getPluginQuery(String queryMode) {
|
||||
return pluginQueries.get(queryMode);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,14 +1,13 @@
|
||||
package com.tencent.supersonic.chat.server.plugin.build;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.headless.core.chat.query.BaseSemanticQuery;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -17,12 +16,11 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Slf4j
|
||||
public abstract class PluginSemanticQuery extends BaseSemanticQuery {
|
||||
public abstract class PluginSemanticQuery {
|
||||
|
||||
@Override
|
||||
public void initS2Sql(SemanticSchema semanticSchema, User user) {
|
||||
protected SemanticParseInfo parseInfo;
|
||||
|
||||
}
|
||||
public abstract QueryResult build();
|
||||
|
||||
private Map<Long, Object> getFilterMap(PluginParseResult pluginParseResult) {
|
||||
Map<Long, Object> map = new HashMap<>();
|
||||
@@ -91,4 +89,8 @@ public abstract class PluginSemanticQuery extends BaseSemanticQuery {
|
||||
return webBaseResult;
|
||||
}
|
||||
|
||||
public void setParseInfo(SemanticParseInfo parseInfo) {
|
||||
this.parseInfo = parseInfo;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -2,15 +2,18 @@ package com.tencent.supersonic.chat.server.plugin.build.webpage;
|
||||
|
||||
import com.tencent.supersonic.chat.server.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.calcite.sql.parser.SqlParseException;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class WebPageQuery extends PluginSemanticQuery {
|
||||
@@ -18,17 +21,7 @@ public class WebPageQuery extends PluginSemanticQuery {
|
||||
public static String QUERY_MODE = "WEB_PAGE";
|
||||
|
||||
public WebPageQuery() {
|
||||
QueryManager.register(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getQueryMode() {
|
||||
return QUERY_MODE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SemanticQueryReq buildSemanticQueryReq() throws SqlParseException {
|
||||
return null;
|
||||
PluginQueryManager.register(QUERY_MODE, this);
|
||||
}
|
||||
|
||||
protected WebPageResp buildResponse(PluginParseResult pluginParseResult) {
|
||||
@@ -43,4 +36,17 @@ public class WebPageQuery extends PluginSemanticQuery {
|
||||
return webPageResponse;
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult build() {
|
||||
QueryResult queryResult = new QueryResult();
|
||||
queryResult.setQueryMode(QUERY_MODE);
|
||||
Map<String, Object> properties = parseInfo.getProperties();
|
||||
PluginParseResult pluginParseResult = JsonUtil.toObject(JsonUtil.toString(properties.get(Constants.CONTEXT)),
|
||||
PluginParseResult.class);
|
||||
WebPageResp webPageResponse = buildResponse(pluginParseResult);
|
||||
queryResult.setResponse(webPageResponse);
|
||||
queryResult.setQueryState(QueryState.SUCCESS);
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -3,15 +3,17 @@ package com.tencent.supersonic.chat.server.plugin.build.webservice;
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.tencent.supersonic.chat.server.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.ParamOption;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.calcite.sql.parser.SqlParseException;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
@@ -36,17 +38,30 @@ public class WebServiceQuery extends PluginSemanticQuery {
|
||||
private RestTemplate restTemplate;
|
||||
|
||||
public WebServiceQuery() {
|
||||
QueryManager.register(this);
|
||||
PluginQueryManager.register(QUERY_MODE, this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getQueryMode() {
|
||||
return QUERY_MODE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SemanticQueryReq buildSemanticQueryReq() throws SqlParseException {
|
||||
return null;
|
||||
public QueryResult build() {
|
||||
QueryResult queryResult = new QueryResult();
|
||||
queryResult.setQueryMode(QUERY_MODE);
|
||||
Map<String, Object> properties = parseInfo.getProperties();
|
||||
PluginParseResult pluginParseResult = JsonUtil.toObject(
|
||||
JsonUtil.toString(properties.get(Constants.CONTEXT)), PluginParseResult.class);
|
||||
WebServiceResp webServiceResponse = buildResponse(pluginParseResult);
|
||||
Object object = webServiceResponse.getResult();
|
||||
// in order to show webServiceQuery result int frontend conveniently,
|
||||
// webServiceResponse result format is consistent with queryByStruct result.
|
||||
log.info("webServiceResponse result:{}", JsonUtil.toString(object));
|
||||
try {
|
||||
Map<String, Object> data = JsonUtil.toMap(JsonUtil.toString(object), String.class, Object.class);
|
||||
queryResult.setQueryResults((List<Map<String, Object>>) data.get("resultList"));
|
||||
queryResult.setQueryColumns((List<QueryColumn>) data.get("columns"));
|
||||
queryResult.setQueryState(QueryState.SUCCESS);
|
||||
} catch (Exception e) {
|
||||
log.info("webServiceResponse result has an exception:{}", e.getMessage());
|
||||
}
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
protected WebServiceResp buildResponse(PluginParseResult pluginParseResult) {
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.plugin.recall.function;
|
||||
|
||||
import lombok.Data;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
@Configuration
|
||||
@Data
|
||||
public class FunctionCallConfig {
|
||||
@Value("${functionCall.url:}")
|
||||
private String url;
|
||||
|
||||
@Value("${funtionCall.plugin.select.path:/plugin_selection}")
|
||||
private String pluginSelectPath;
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.plugin.recall.function;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class FunctionFiled {
|
||||
|
||||
private String type;
|
||||
|
||||
private String description;
|
||||
|
||||
}
|
||||
@@ -1,75 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.plugin.recall.function;
|
||||
|
||||
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.llm.InputFormat;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class FunctionPromptGenerator {
|
||||
|
||||
public String generateFunctionCallPrompt(String queryText, List<PluginParseConfig> toolConfigList) {
|
||||
List<String> toolExplainList = toolConfigList.stream()
|
||||
.map(this::constructPluginPrompt)
|
||||
.collect(Collectors.toList());
|
||||
String functionList = String.join(InputFormat.SEPERATOR, toolExplainList);
|
||||
return constructTaskPrompt(queryText, functionList);
|
||||
}
|
||||
|
||||
public String constructPluginPrompt(PluginParseConfig parseConfig) {
|
||||
String toolName = parseConfig.getName();
|
||||
String toolDescription = parseConfig.getDescription();
|
||||
List<String> toolExamples = parseConfig.getExamples();
|
||||
|
||||
StringBuilder prompt = new StringBuilder();
|
||||
prompt.append("【工具名称】\n").append(toolName).append("\n");
|
||||
prompt.append("【工具描述】\n").append(toolDescription).append("\n");
|
||||
prompt.append("【工具适用问题示例】\n");
|
||||
for (String example : toolExamples) {
|
||||
prompt.append(example).append("\n");
|
||||
}
|
||||
return prompt.toString();
|
||||
}
|
||||
|
||||
public String constructTaskPrompt(String queryText, String functionList) {
|
||||
String instruction = String.format("问题为:%s\n请根据问题和工具的描述,选择对应的工具,完成任务。"
|
||||
+ "请注意,只能选择1个工具。请一步一步地分析选择工具的原因(每个工具的【工具适用问题示例】是选择的重要参考依据),"
|
||||
+ "并给出最终选择,输出格式为json,key为’分析过程‘, ’选择工具‘", queryText);
|
||||
|
||||
return String.format("工具选择如下:\n\n%s\n\n【任务说明】\n%s", functionList, instruction);
|
||||
}
|
||||
|
||||
public FunctionResp requestFunction(FunctionReq functionReq) {
|
||||
|
||||
FunctionPromptGenerator promptGenerator = ContextUtils.getBean(FunctionPromptGenerator.class);
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
||||
String functionCallPrompt = promptGenerator.generateFunctionCallPrompt(functionReq.getQueryText(),
|
||||
functionReq.getPluginConfigs());
|
||||
String response = chatLanguageModel.generate(functionCallPrompt);
|
||||
return functionCallParse(response);
|
||||
}
|
||||
|
||||
public static FunctionResp functionCallParse(String llmOutput) {
|
||||
try {
|
||||
ObjectMapper objectMapper = new ObjectMapper();
|
||||
JsonNode jsonNode = objectMapper.readTree(llmOutput);
|
||||
String selectedTool = jsonNode.get("选择工具").asText();
|
||||
FunctionResp resp = new FunctionResp();
|
||||
resp.setToolSelection(selectedTool);
|
||||
return resp;
|
||||
} catch (Exception e) {
|
||||
log.error("", e);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.plugin.recall.function;
|
||||
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
public class FunctionReq {
|
||||
|
||||
private String queryText;
|
||||
|
||||
private List<PluginParseConfig> pluginConfigs;
|
||||
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.plugin.recall.function;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class FunctionResp {
|
||||
|
||||
private String toolSelection;
|
||||
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.plugin.recall.function;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Data
|
||||
public class Parameters {
|
||||
|
||||
//default: object
|
||||
private String type = "object";
|
||||
|
||||
private Map<String, FunctionFiled> properties;
|
||||
|
||||
private List<String> required;
|
||||
|
||||
}
|
||||
@@ -1,13 +1,12 @@
|
||||
package com.tencent.supersonic.chat.server.plugin.recall;
|
||||
package com.tencent.supersonic.chat.server.plugin.recognize;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.server.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
@@ -15,7 +14,7 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
@@ -26,64 +25,58 @@ import java.util.Set;
|
||||
/**
|
||||
* PluginParser defines the basic process and common methods for recalling plugins.
|
||||
*/
|
||||
public abstract class PluginParser {
|
||||
public abstract class PluginRecognizer {
|
||||
|
||||
public void parse(ChatParseReq chatParseReq) {
|
||||
if (!checkPreCondition(chatParseReq)) {
|
||||
public void recognize(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
if (!checkPreCondition(chatParseContext)) {
|
||||
return;
|
||||
}
|
||||
PluginRecallResult pluginRecallResult = recallPlugin(chatParseReq);
|
||||
PluginRecallResult pluginRecallResult = recallPlugin(chatParseContext);
|
||||
if (pluginRecallResult == null) {
|
||||
return;
|
||||
}
|
||||
buildQuery(chatParseReq, pluginRecallResult);
|
||||
buildQuery(chatParseContext, parseResp, pluginRecallResult);
|
||||
}
|
||||
|
||||
public abstract boolean checkPreCondition(ChatParseReq chatParseReq);
|
||||
public abstract boolean checkPreCondition(ChatParseContext chatParseContext);
|
||||
|
||||
public abstract PluginRecallResult recallPlugin(ChatParseReq chatParseReq);
|
||||
public abstract PluginRecallResult recallPlugin(ChatParseContext chatParseContext);
|
||||
|
||||
public void buildQuery(ChatParseReq chatParseReq, PluginRecallResult pluginRecallResult) {
|
||||
public void buildQuery(ChatParseContext chatParseContext, ParseResp parseResp,
|
||||
PluginRecallResult pluginRecallResult) {
|
||||
Plugin plugin = pluginRecallResult.getPlugin();
|
||||
Set<Long> dataSetIds = pluginRecallResult.getDataSetIds();
|
||||
if (plugin.isContainsAllModel()) {
|
||||
dataSetIds = Sets.newHashSet(-1L);
|
||||
}
|
||||
for (Long dataSetId : dataSetIds) {
|
||||
//todo
|
||||
PluginSemanticQuery pluginQuery = null;
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin,
|
||||
null, pluginRecallResult.getDistance());
|
||||
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
|
||||
chatParseContext, pluginRecallResult.getDistance());
|
||||
semanticParseInfo.setQueryMode(plugin.getType());
|
||||
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
||||
pluginQuery.setParseInfo(semanticParseInfo);
|
||||
//chatParseReq.getCandidateQueries().add(pluginQuery);
|
||||
parseResp.getSelectedParses().add(semanticParseInfo);
|
||||
}
|
||||
}
|
||||
|
||||
protected List<Plugin> getPluginList(ChatParseReq chatParseReq) {
|
||||
return PluginManager.getPluginAgentCanSupport(chatParseReq);
|
||||
protected List<Plugin> getPluginList(ChatParseContext chatParseContext) {
|
||||
return PluginManager.getPluginAgentCanSupport(chatParseContext);
|
||||
}
|
||||
|
||||
protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, Plugin plugin,
|
||||
QueryContext queryContext, double distance) {
|
||||
List<SchemaElementMatch> schemaElementMatches = queryContext.getMapInfo().getMatchedElements(dataSetId);
|
||||
QueryFilters queryFilters = queryContext.getQueryFilters();
|
||||
if (dataSetId == null && !CollectionUtils.isEmpty(plugin.getDataSetList())) {
|
||||
dataSetId = plugin.getDataSetList().get(0);
|
||||
}
|
||||
ChatParseContext chatParseContext, double distance) {
|
||||
List<SchemaElementMatch> schemaElementMatches = chatParseContext.getMapInfo().getMatchedElements(dataSetId);
|
||||
QueryFilters queryFilters = chatParseContext.getQueryFilters();
|
||||
if (schemaElementMatches == null) {
|
||||
schemaElementMatches = Lists.newArrayList();
|
||||
}
|
||||
SemanticParseInfo semanticParseInfo = new SemanticParseInfo();
|
||||
semanticParseInfo.setElementMatches(schemaElementMatches);
|
||||
semanticParseInfo.setDataSet(queryContext.getSemanticSchema().getDataSet(dataSetId));
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
PluginParseResult pluginParseResult = new PluginParseResult();
|
||||
pluginParseResult.setPlugin(plugin);
|
||||
pluginParseResult.setQueryFilters(queryFilters);
|
||||
pluginParseResult.setDistance(distance);
|
||||
pluginParseResult.setQueryText(queryContext.getQueryText());
|
||||
pluginParseResult.setQueryText(chatParseContext.getQueryText());
|
||||
properties.put(Constants.CONTEXT, pluginParseResult);
|
||||
properties.put("type", "plugin");
|
||||
properties.put("name", plugin.getName());
|
||||
@@ -111,4 +104,5 @@ public abstract class PluginParser {
|
||||
semanticParseInfo.getDimensionFilters().add(queryFilter);
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,12 +1,12 @@
|
||||
package com.tencent.supersonic.chat.server.plugin.recall.embedding;
|
||||
package com.tencent.supersonic.chat.server.plugin.recognize.embedding;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.server.plugin.ParseMode;
|
||||
import com.tencent.supersonic.chat.server.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.recall.PluginParser;
|
||||
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
@@ -28,32 +28,31 @@ import java.util.stream.Collectors;
|
||||
* EmbeddingRecallParser is an implementation of a recall plugin based on Embedding
|
||||
*/
|
||||
@Slf4j
|
||||
public class EmbeddingRecallParser extends PluginParser {
|
||||
public class EmbeddingRecallRecognizer extends PluginRecognizer {
|
||||
|
||||
public boolean checkPreCondition(ChatParseReq chatParseReq) {
|
||||
public boolean checkPreCondition(ChatParseContext chatParseContext) {
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
if (StringUtils.isBlank(embeddingConfig.getUrl()) && ComponentFactory.getLLMProxy() instanceof PythonLLMProxy) {
|
||||
return false;
|
||||
}
|
||||
List<Plugin> plugins = getPluginList(chatParseReq);
|
||||
List<Plugin> plugins = getPluginList(chatParseContext);
|
||||
return !CollectionUtils.isEmpty(plugins);
|
||||
}
|
||||
|
||||
public PluginRecallResult recallPlugin(ChatParseReq chatParseReq) {
|
||||
String text = chatParseReq.getQueryText();
|
||||
public PluginRecallResult recallPlugin(ChatParseContext chatParseContext) {
|
||||
String text = chatParseContext.getQueryText();
|
||||
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
|
||||
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
|
||||
return null;
|
||||
}
|
||||
List<Plugin> plugins = getPluginList(chatParseReq);
|
||||
List<Plugin> plugins = getPluginList(chatParseContext);
|
||||
Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p));
|
||||
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
|
||||
Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
|
||||
if (plugin == null) {
|
||||
continue;
|
||||
}
|
||||
//todo
|
||||
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, null);
|
||||
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, chatParseContext);
|
||||
log.info("embedding plugin resolve: {}", pair);
|
||||
if (pair.getLeft()) {
|
||||
Set<Long> dataSetList = pair.getRight();
|
||||
@@ -62,7 +61,7 @@ public class EmbeddingRecallParser extends PluginParser {
|
||||
}
|
||||
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
||||
double distance = embeddingRetrieval.getDistance();
|
||||
double score = chatParseReq.getQueryText().length() * (1 - distance);
|
||||
double score = chatParseContext.getQueryText().length() * (1 - distance);
|
||||
return PluginRecallResult.builder()
|
||||
.plugin(plugin).dataSetIds(dataSetList).score(score).distance(distance).build();
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.server.plugin.recall.embedding;
|
||||
package com.tencent.supersonic.chat.server.plugin.recognize.embedding;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.server.plugin.recall.embedding;
|
||||
package com.tencent.supersonic.chat.server.plugin.recognize.embedding;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.tencent.supersonic.chat.server.pojo;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ChatExecuteContext {
|
||||
private User user;
|
||||
private Long queryId;
|
||||
private Integer chatId;
|
||||
private int parseId;
|
||||
private String queryText;
|
||||
private boolean saveAnswer;
|
||||
private SemanticParseInfo parseInfo;
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.tencent.supersonic.chat.server.pojo;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ChatParseContext {
|
||||
private String queryText;
|
||||
private Integer chatId;
|
||||
private Integer agentId;
|
||||
private User user;
|
||||
private QueryFilters queryFilters;
|
||||
private boolean saveAnswer = true;
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
}
|
||||
@@ -1,10 +1,10 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
|
||||
public interface ParseResultProcessor {
|
||||
|
||||
void process(ParseResp parseResp, ChatParseReq chatParseReq);
|
||||
void process(ChatParseContext chatParseContext, ParseResp parseResp);
|
||||
|
||||
}
|
||||
|
||||
@@ -7,12 +7,10 @@ import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.headless.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.server.processor.ResultProcessor;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -26,18 +24,17 @@ import java.util.stream.Collectors;
|
||||
* MetricRecommendProcessor fills recommended query based on embedding similarity.
|
||||
*/
|
||||
@Slf4j
|
||||
public class QueryRecommendProcessor implements ResultProcessor {
|
||||
public class QueryRecommendProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
|
||||
CompletableFuture.runAsync(() -> doProcess(parseResp, queryContext));
|
||||
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
CompletableFuture.runAsync(() -> doProcess(parseResp, chatParseContext));
|
||||
}
|
||||
|
||||
@SneakyThrows
|
||||
private void doProcess(ParseResp parseResp, QueryContext queryContext) {
|
||||
private void doProcess(ParseResp parseResp, ChatParseContext chatParseContext) {
|
||||
Long queryId = parseResp.getQueryId();
|
||||
//TODO
|
||||
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(queryContext.getQueryText(),
|
||||
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(chatParseContext.getQueryText(),
|
||||
null);
|
||||
ChatQueryDO chatQueryDO = getChatQuery(queryId);
|
||||
chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries));
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* RespBuildProcessor fill response object with parsing results.
|
||||
**/
|
||||
@Slf4j
|
||||
public class RespBuildProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
parseResp.setChatId(chatParseContext.getChatId());
|
||||
parseResp.setQueryText(chatParseContext.getQueryText());
|
||||
List<SemanticParseInfo> parseInfos = parseResp.getSelectedParses();
|
||||
if (CollectionUtils.isNotEmpty(parseInfos)) {
|
||||
parseResp.setState(ParseResp.ParseState.COMPLETED);
|
||||
} else {
|
||||
parseResp.setState(ParseResp.ParseState.FAILED);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -7,22 +7,27 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.executor.ChatExecutor;
|
||||
import com.tencent.supersonic.chat.server.parser.ChatParser;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.QueryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatRepository;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
@@ -33,7 +38,6 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.text.SimpleDateFormat;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
@@ -55,32 +59,65 @@ public class ChatServiceImpl implements ChatService {
|
||||
@Autowired
|
||||
private ChatQueryService chatQueryService;
|
||||
@Autowired
|
||||
private AgentService agentService;
|
||||
@Autowired
|
||||
private SearchService searchService;
|
||||
private List<ChatParser> chatParsers = ComponentFactory.getChatParsers();
|
||||
private List<ChatExecutor> chatExecutors = ComponentFactory.getChatExecutors();
|
||||
private List<ParseResultProcessor> parseResultProcessors = ComponentFactory.getParseProcessors();
|
||||
|
||||
@Override
|
||||
public List<SearchResult> search(ChatParseReq chatParseReq) {
|
||||
QueryReq queryReq = buildSqlQueryReq(chatParseReq);
|
||||
ChatParseContext chatParseContext = buildParseContext(chatParseReq);
|
||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
return searchService.search(queryReq);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ParseResp performParsing(ChatParseReq chatParseReq) {
|
||||
QueryReq queryReq = buildSqlQueryReq(chatParseReq);
|
||||
ParseResp parseResp = chatQueryService.performParsing(queryReq);
|
||||
ParseResp parseResp = new ParseResp();
|
||||
ChatParseContext chatParseContext = buildParseContext(chatParseReq);
|
||||
for (ChatParser chatParser : chatParsers) {
|
||||
chatParser.parse(chatParseContext, parseResp);
|
||||
}
|
||||
for (ParseResultProcessor processor : parseResultProcessors) {
|
||||
processor.process(chatParseContext, parseResp);
|
||||
}
|
||||
batchAddParse(chatParseReq, parseResp);
|
||||
return parseResp;
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult performExecution(ChatExecuteReq chatExecuteReq) throws Exception {
|
||||
ExecuteQueryReq executeQueryReq = buildExecuteReq(chatExecuteReq);
|
||||
QueryResult queryResult = chatQueryService.performExecution(executeQueryReq);
|
||||
public QueryResult performExecution(ChatExecuteReq chatExecuteReq) {
|
||||
QueryResult queryResult = new QueryResult();
|
||||
ChatExecuteContext chatExecuteContext = buildExecuteContext(chatExecuteReq);
|
||||
for (ChatExecutor chatExecutor : chatExecutors) {
|
||||
queryResult = chatExecutor.execute(chatExecuteContext);
|
||||
if (queryResult != null) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
saveQueryResult(chatExecuteReq, queryResult);
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
private ChatParseContext buildParseContext(ChatParseReq chatParseReq) {
|
||||
ChatParseContext chatParseContext = new ChatParseContext();
|
||||
BeanMapper.mapper(chatParseReq, chatParseContext);
|
||||
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
MapResp mapResp = chatQueryService.performMapping(queryReq);
|
||||
chatParseContext.setMapInfo(mapResp.getMapInfo());
|
||||
return chatParseContext;
|
||||
}
|
||||
|
||||
private ChatExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
|
||||
ChatExecuteContext chatExecuteContext = new ChatExecuteContext();
|
||||
BeanMapper.mapper(chatExecuteReq, chatExecuteContext);
|
||||
ChatParseDO chatParseDO = getParseInfo(chatExecuteReq.getQueryId(), chatExecuteReq.getParseId());
|
||||
SemanticParseInfo semanticParseInfo = JSONObject.parseObject(chatParseDO.getParseInfo(),
|
||||
SemanticParseInfo.class);
|
||||
chatExecuteContext.setParseInfo(semanticParseInfo);
|
||||
return chatExecuteContext;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object queryData(QueryDataReq queryData, User user) throws Exception {
|
||||
return chatQueryService.executeDirectQuery(queryData, user);
|
||||
@@ -96,36 +133,6 @@ public class ChatServiceImpl implements ChatService {
|
||||
return chatQueryService.queryDimensionValue(dimensionValueReq, user);
|
||||
}
|
||||
|
||||
private QueryReq buildSqlQueryReq(ChatParseReq chatParseReq) {
|
||||
QueryReq queryReq = new QueryReq();
|
||||
BeanMapper.mapper(chatParseReq, queryReq);
|
||||
if (chatParseReq.getAgentId() == null) {
|
||||
return queryReq;
|
||||
}
|
||||
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
|
||||
if (agent == null) {
|
||||
return queryReq;
|
||||
}
|
||||
if (agent.containsLLMParserTool()) {
|
||||
queryReq.setEnableLLM(true);
|
||||
}
|
||||
queryReq.setDataSetIds(agent.getDataSetIds());
|
||||
return queryReq;
|
||||
}
|
||||
|
||||
private ExecuteQueryReq buildExecuteReq(ChatExecuteReq chatExecuteReq) {
|
||||
ChatParseDO chatParseDO = getParseInfo(chatExecuteReq.getQueryId(), chatExecuteReq.getParseId());
|
||||
SemanticParseInfo parseInfo = JSONObject.parseObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
|
||||
return ExecuteQueryReq.builder()
|
||||
.queryId(chatExecuteReq.getQueryId())
|
||||
.chatId(chatExecuteReq.getChatId())
|
||||
.queryText(chatExecuteReq.getQueryText())
|
||||
.parseInfo(parseInfo)
|
||||
.saveAnswer(chatExecuteReq.isSaveAnswer())
|
||||
.user(chatExecuteReq.getUser())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Boolean addChat(User user, String chatName, Integer agentId) {
|
||||
ChatDO chatDO = new ChatDO();
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
package com.tencent.supersonic.chat.server.util;
|
||||
|
||||
import com.tencent.supersonic.chat.server.executor.ChatExecutor;
|
||||
import com.tencent.supersonic.chat.server.parser.ChatParser;
|
||||
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
|
||||
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
|
||||
import com.tencent.supersonic.headless.server.processor.ResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.core.io.support.SpringFactoriesLoader;
|
||||
@@ -11,11 +14,14 @@ import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
public class ComponentFactory {
|
||||
private static List<ResultProcessor> parseProcessors = new ArrayList<>();
|
||||
private static List<ParseResultProcessor> parseProcessors = new ArrayList<>();
|
||||
private static List<ExecuteResultProcessor> executeProcessors = new ArrayList<>();
|
||||
private static List<ChatParser> chatParsers = new ArrayList<>();
|
||||
private static List<ChatExecutor> chatExecutors = new ArrayList<>();
|
||||
private static List<PluginRecognizer> pluginRecognizers = new ArrayList<>();
|
||||
|
||||
public static List<ResultProcessor> getParseProcessors() {
|
||||
return CollectionUtils.isEmpty(parseProcessors) ? init(ResultProcessor.class,
|
||||
public static List<ParseResultProcessor> getParseProcessors() {
|
||||
return CollectionUtils.isEmpty(parseProcessors) ? init(ParseResultProcessor.class,
|
||||
parseProcessors) : parseProcessors;
|
||||
}
|
||||
|
||||
@@ -24,6 +30,21 @@ public class ComponentFactory {
|
||||
? init(ExecuteResultProcessor.class, executeProcessors) : executeProcessors;
|
||||
}
|
||||
|
||||
public static List<ChatParser> getChatParsers() {
|
||||
return CollectionUtils.isEmpty(chatParsers)
|
||||
? init(ChatParser.class, chatParsers) : chatParsers;
|
||||
}
|
||||
|
||||
public static List<ChatExecutor> getChatExecutors() {
|
||||
return CollectionUtils.isEmpty(chatExecutors)
|
||||
? init(ChatExecutor.class, chatExecutors) : chatExecutors;
|
||||
}
|
||||
|
||||
public static List<PluginRecognizer> getPluginRecognizers() {
|
||||
return CollectionUtils.isEmpty(pluginRecognizers)
|
||||
? init(PluginRecognizer.class, pluginRecognizers) : pluginRecognizers;
|
||||
}
|
||||
|
||||
private static <T> List<T> init(Class<T> factoryType, List list) {
|
||||
list.addAll(SpringFactoriesLoader.loadFactories(factoryType,
|
||||
Thread.currentThread().getContextClassLoader()));
|
||||
@@ -34,4 +55,5 @@ public class ComponentFactory {
|
||||
return SpringFactoriesLoader.loadFactories(factoryType,
|
||||
Thread.currentThread().getContextClassLoader()).get(0);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
package com.tencent.supersonic.chat.server.util;
|
||||
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
|
||||
public class QueryReqConverter {
|
||||
|
||||
public static QueryReq buildText2SqlQueryReq(ChatParseContext chatParseContext) {
|
||||
QueryReq queryReq = new QueryReq();
|
||||
BeanMapper.mapper(chatParseContext, queryReq);
|
||||
if (chatParseContext.getAgentId() == null) {
|
||||
return queryReq;
|
||||
}
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent agent = agentService.getAgent(chatParseContext.getAgentId());
|
||||
if (agent == null) {
|
||||
return queryReq;
|
||||
}
|
||||
if (agent.containsLLMParserTool()) {
|
||||
queryReq.setEnableLLM(true);
|
||||
}
|
||||
queryReq.setDataSetIds(agent.getDataSetIds());
|
||||
return queryReq;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.tencent.supersonic.headless.api.pojo.response;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class MapResp {
|
||||
|
||||
private String queryText;
|
||||
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
|
||||
}
|
||||
@@ -11,7 +11,6 @@ import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
|
||||
@@ -22,7 +21,6 @@ public class ChatQueryApiController {
|
||||
|
||||
@Autowired
|
||||
private ChatQueryService chatQueryService;
|
||||
|
||||
@Autowired
|
||||
private SearchService searchService;
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
|
||||
@@ -15,6 +16,8 @@ import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
*/
|
||||
public interface ChatQueryService {
|
||||
|
||||
MapResp performMapping(QueryReq queryReq);
|
||||
|
||||
ParseResp performParsing(QueryReq queryReq);
|
||||
|
||||
QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception;
|
||||
|
||||
@@ -32,6 +32,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
@@ -106,6 +107,18 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSemanticCorrectors();
|
||||
private List<ResultProcessor> resultProcessors = ComponentFactory.getResultProcessors();
|
||||
|
||||
@Override
|
||||
public MapResp performMapping(QueryReq queryReq) {
|
||||
MapResp mapResp = new MapResp();
|
||||
QueryContext queryCtx = buildQueryContext(queryReq);
|
||||
schemaMappers.forEach(mapper -> {
|
||||
mapper.map(queryCtx);
|
||||
});
|
||||
SchemaMapInfo mapInfo = queryCtx.getMapInfo();
|
||||
mapResp.setMapInfo(mapInfo);
|
||||
return mapResp;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ParseResp performParsing(QueryReq queryReq) {
|
||||
ParseResp parseResult = new ParseResp();
|
||||
|
||||
@@ -9,6 +9,14 @@ com.tencent.supersonic.headless.core.chat.parser.SemanticParser=\
|
||||
com.tencent.supersonic.headless.core.chat.parser.llm.LLMSqlParser, \
|
||||
com.tencent.supersonic.headless.core.chat.parser.QueryTypeParser
|
||||
|
||||
com.tencent.supersonic.chat.server.parser.ChatParser=\
|
||||
com.tencent.supersonic.chat.server.parser.Text2PluginParser, \
|
||||
com.tencent.supersonic.chat.server.parser.Text2SqlParser
|
||||
|
||||
com.tencent.supersonic.chat.server.executor.ChatExecutor=\
|
||||
com.tencent.supersonic.chat.server.executor.PluginExecutor, \
|
||||
com.tencent.supersonic.chat.server.executor.SqlExecutor
|
||||
|
||||
com.tencent.supersonic.headless.core.chat.corrector.SemanticCorrector=\
|
||||
com.tencent.supersonic.headless.core.chat.corrector.SchemaCorrector, \
|
||||
com.tencent.supersonic.headless.core.chat.corrector.TimeCorrector, \
|
||||
@@ -48,6 +56,13 @@ com.tencent.supersonic.auth.authentication.interceptor.AuthenticationInterceptor
|
||||
com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\
|
||||
com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor
|
||||
|
||||
com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer=\
|
||||
com.tencent.supersonic.chat.server.plugin.recognize.embedding.EmbeddingRecallRecognizer
|
||||
|
||||
com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor=\
|
||||
com.tencent.supersonic.chat.server.processor.parse.RespBuildProcessor,\
|
||||
com.tencent.supersonic.chat.server.processor.parse.QueryRecommendProcessor
|
||||
|
||||
com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor=\
|
||||
com.tencent.supersonic.chat.server.processor.execute.MetricRecommendProcessor,\
|
||||
com.tencent.supersonic.chat.server.processor.execute.DimensionRecommendProcessor,\
|
||||
|
||||
Reference in New Issue
Block a user