mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 02:46:56 +00:00
(improvement)(Chat) Move chat-core to headless (#805)
Co-authored-by: jolunoluo
This commit is contained in:
@@ -16,16 +16,6 @@ public class ChatConfigBaseReq {
|
||||
|
||||
private Long modelId;
|
||||
|
||||
/**
|
||||
* the chatDetailConfig about the model
|
||||
*/
|
||||
private ChatDetailConfigReq chatDetailConfig;
|
||||
|
||||
/**
|
||||
* the chatAggConfig about the model
|
||||
*/
|
||||
private ChatAggConfigReq chatAggConfig;
|
||||
|
||||
|
||||
/**
|
||||
* the recommended questions about the model
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class ChatExecuteReq {
|
||||
private User user;
|
||||
private Long queryId;
|
||||
private Integer chatId;
|
||||
private int parseId;
|
||||
private String queryText;
|
||||
private boolean saveAnswer;
|
||||
|
||||
}
|
||||
@@ -1,15 +1,16 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.request;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class QueryReq {
|
||||
public class ChatParseReq {
|
||||
private String queryText;
|
||||
private Integer chatId;
|
||||
private Long dataSetId;
|
||||
private Integer agentId;
|
||||
private User user;
|
||||
private QueryFilters queryFilters;
|
||||
private boolean saveAnswer = true;
|
||||
private Integer agentId;
|
||||
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.chat.api.pojo.response;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<parent>
|
||||
<artifactId>chat</artifactId>
|
||||
<groupId>com.tencent.supersonic</groupId>
|
||||
<version>${revision}</version>
|
||||
</parent>
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<artifactId>chat-core</artifactId>
|
||||
|
||||
<properties>
|
||||
<maven.compiler.source>8</maven.compiler.source>
|
||||
<maven.compiler.target>8</maven.compiler.target>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.springframework</groupId>
|
||||
<artifactId>spring-context</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.testng</groupId>
|
||||
<artifactId>testng</artifactId>
|
||||
<version>${org.testng.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.junit.jupiter</groupId>
|
||||
<artifactId>junit-jupiter</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.commons</groupId>
|
||||
<artifactId>commons-compress</artifactId>
|
||||
<version>${commons.compress.version}</version>
|
||||
<scope>compile</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-test</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-web</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.alibaba</groupId>
|
||||
<artifactId>druid</artifactId>
|
||||
<version>${alibaba.druid.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>mysql</groupId>
|
||||
<artifactId>mysql-connector-java</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.h2database</groupId>
|
||||
<artifactId>h2</artifactId>
|
||||
<version>${h2.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.tencent.supersonic</groupId>
|
||||
<artifactId>headless-api</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.tencent.supersonic</groupId>
|
||||
<artifactId>headless-core</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.tencent.supersonic</groupId>
|
||||
<artifactId>chat-api</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>compile</scope>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>com.github.xkzhangsan</groupId>
|
||||
<artifactId>xk-time</artifactId>
|
||||
<version>${xk.time.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-inline</artifactId>
|
||||
<version>${mockito-inline.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.tencent.supersonic</groupId>
|
||||
<artifactId>headless-server</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>compile</scope>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
|
||||
</project>
|
||||
@@ -1,66 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionPromptGenerator;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.OutputFormat;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlGeneration;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.SqlGenerationFactory;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.SqlGenerationMode;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* LLMProxy based on langchain4j Java version.
|
||||
*/
|
||||
@Slf4j
|
||||
@Component
|
||||
public class JavaLLMProxy implements LLMProxy {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
@Override
|
||||
public boolean isSkip(QueryContext queryContext) {
|
||||
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
||||
if (Objects.isNull(chatLanguageModel)) {
|
||||
log.warn("chatLanguageModel is null, skip :{}", JavaLLMProxy.class.getName());
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public LLMResp query2sql(LLMReq llmReq, Long dataSetId) {
|
||||
|
||||
SqlGeneration sqlGeneration = SqlGenerationFactory.get(
|
||||
SqlGenerationMode.getMode(llmReq.getSqlGenerationMode()));
|
||||
String modelName = llmReq.getSchema().getDataSetName();
|
||||
LLMResp result = sqlGeneration.generation(llmReq, dataSetId);
|
||||
result.setQuery(llmReq.getQueryText());
|
||||
result.setModelName(modelName);
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FunctionResp requestFunction(FunctionReq functionReq) {
|
||||
|
||||
FunctionPromptGenerator promptGenerator = ContextUtils.getBean(FunctionPromptGenerator.class);
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
||||
String functionCallPrompt = promptGenerator.generateFunctionCallPrompt(functionReq.getQueryText(),
|
||||
functionReq.getPluginConfigs());
|
||||
keyPipelineLog.info("functionCallPrompt:{}", functionCallPrompt);
|
||||
String response = chatLanguageModel.generate(functionCallPrompt);
|
||||
keyPipelineLog.info("functionCall response:{}", response);
|
||||
return OutputFormat.functionCallParse(response);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionReq;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.FunctionResp;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
|
||||
/**
|
||||
* LLMProxy encapsulates functions performed by LLMs so that multiple
|
||||
* orchestration frameworks (e.g. LangChain in python, LangChain4j in java)
|
||||
* could be used.
|
||||
*/
|
||||
public interface LLMProxy {
|
||||
|
||||
boolean isSkip(QueryContext queryContext);
|
||||
|
||||
LLMResp query2sql(LLMReq llmReq, Long dataSetId);
|
||||
|
||||
FunctionResp requestFunction(FunctionReq functionReq);
|
||||
|
||||
}
|
||||
@@ -1,127 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.parser.plugin.function;
|
||||
|
||||
import com.tencent.supersonic.chat.core.parser.PythonLLMProxy;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.ParseMode;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.PluginParser;
|
||||
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginParseConfig;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* FunctionCallParser is an implementation of a recall plugin based on FunctionCall
|
||||
*/
|
||||
@Slf4j
|
||||
public class FunctionCallParser extends PluginParser {
|
||||
|
||||
@Override
|
||||
public boolean checkPreCondition(QueryContext queryContext) {
|
||||
FunctionCallConfig functionCallConfig = ContextUtils.getBean(FunctionCallConfig.class);
|
||||
String functionUrl = functionCallConfig.getUrl();
|
||||
if (StringUtils.isBlank(functionUrl) && ComponentFactory.getLLMProxy() instanceof PythonLLMProxy) {
|
||||
log.info("functionUrl:{}, skip function parser, queryText:{}", functionUrl,
|
||||
queryContext.getQueryText());
|
||||
return false;
|
||||
}
|
||||
List<Plugin> plugins = getPluginList(queryContext);
|
||||
return !CollectionUtils.isEmpty(plugins);
|
||||
}
|
||||
|
||||
@Override
|
||||
public PluginRecallResult recallPlugin(QueryContext queryContext) {
|
||||
FunctionResp functionResp = functionCall(queryContext);
|
||||
if (skipFunction(functionResp)) {
|
||||
return null;
|
||||
}
|
||||
log.info("requestFunction result:{}", functionResp.getToolSelection());
|
||||
String toolSelection = functionResp.getToolSelection();
|
||||
Plugin plugin = queryContext.getNameToPlugin().get(toolSelection);
|
||||
if (Objects.isNull(plugin)) {
|
||||
log.info("pluginOptional is not exist:{}, skip the parse", toolSelection);
|
||||
return null;
|
||||
}
|
||||
plugin.setParseMode(ParseMode.FUNCTION_CALL);
|
||||
Pair<Boolean, Set<Long>> pluginResolveResult = PluginManager.resolve(plugin, queryContext);
|
||||
if (pluginResolveResult.getLeft()) {
|
||||
Set<Long> dataSetList = pluginResolveResult.getRight();
|
||||
if (CollectionUtils.isEmpty(dataSetList)) {
|
||||
return null;
|
||||
}
|
||||
double score = queryContext.getQueryText().length();
|
||||
return PluginRecallResult.builder().plugin(plugin).dataSetIds(dataSetList).score(score).build();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public FunctionResp functionCall(QueryContext queryContext) {
|
||||
List<PluginParseConfig> pluginToFunctionCall =
|
||||
getPluginToFunctionCall(queryContext.getDataSetId(), queryContext);
|
||||
if (CollectionUtils.isEmpty(pluginToFunctionCall)) {
|
||||
log.info("function call parser, plugin is empty, skip");
|
||||
return null;
|
||||
}
|
||||
FunctionResp functionResp = new FunctionResp();
|
||||
if (pluginToFunctionCall.size() == 1) {
|
||||
functionResp.setToolSelection(pluginToFunctionCall.iterator().next().getName());
|
||||
} else {
|
||||
FunctionReq functionReq = FunctionReq.builder()
|
||||
.queryText(queryContext.getQueryText())
|
||||
.pluginConfigs(pluginToFunctionCall).build();
|
||||
functionResp = ComponentFactory.getLLMProxy().requestFunction(functionReq);
|
||||
}
|
||||
return functionResp;
|
||||
}
|
||||
|
||||
private boolean skipFunction(FunctionResp functionResp) {
|
||||
return Objects.isNull(functionResp) || StringUtils.isBlank(functionResp.getToolSelection());
|
||||
}
|
||||
|
||||
private List<PluginParseConfig> getPluginToFunctionCall(Long modelId, QueryContext queryContext) {
|
||||
log.info("user decide Model:{}", modelId);
|
||||
List<Plugin> plugins = getPluginList(queryContext);
|
||||
List<PluginParseConfig> functionDOList = plugins.stream().filter(plugin -> {
|
||||
if (LLMSqlQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) {
|
||||
return false;
|
||||
}
|
||||
if (plugin.getParseModeConfig() == null) {
|
||||
return false;
|
||||
}
|
||||
PluginParseConfig pluginParseConfig = JsonUtil.toObject(plugin.getParseModeConfig(),
|
||||
PluginParseConfig.class);
|
||||
if (StringUtils.isBlank(pluginParseConfig.getName())) {
|
||||
return false;
|
||||
}
|
||||
Pair<Boolean, Set<Long>> pluginResolverResult = PluginManager.resolve(plugin, queryContext);
|
||||
log.info("plugin [{}-{}] resolve: {}", plugin.getId(), plugin.getName(), pluginResolverResult);
|
||||
if (!pluginResolverResult.getLeft()) {
|
||||
return false;
|
||||
} else {
|
||||
Set<Long> resolveModel = pluginResolverResult.getRight();
|
||||
if (modelId != null && modelId > 0) {
|
||||
if (plugin.isContainsAllModel()) {
|
||||
return true;
|
||||
}
|
||||
return resolveModel.contains(modelId);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}).map(o -> JsonUtil.toObject(o.getParseModeConfig(), PluginParseConfig.class)).collect(Collectors.toList());
|
||||
log.info("PluginToFunctionCall: {}", JsonUtil.toString(functionDOList));
|
||||
return functionDOList;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,32 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.parser.sql.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.core.agent.NL2SQLTool;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMReq.ElementValue;
|
||||
import com.tencent.supersonic.chat.core.query.llm.s2sql.LLMResp;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@AllArgsConstructor
|
||||
@NoArgsConstructor
|
||||
public class ParseResult {
|
||||
|
||||
private Long dataSetId;
|
||||
|
||||
private LLMReq llmReq;
|
||||
|
||||
private LLMResp llmResp;
|
||||
|
||||
private QueryReq request;
|
||||
|
||||
private NL2SQLTool commonAgentTool;
|
||||
|
||||
private List<ElementValue> linkingValues;
|
||||
}
|
||||
@@ -1,73 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.parser.sql.rule;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.core.agent.Agent;
|
||||
import com.tencent.supersonic.chat.core.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.core.agent.RuleParserTool;
|
||||
import com.tencent.supersonic.chat.core.parser.SemanticParser;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
public class AgentCheckParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
List<SemanticQuery> queries = queryContext.getCandidateQueries();
|
||||
log.info("query size before agent filter:{}", queryContext.getCandidateQueries().size());
|
||||
filterQueries(queryContext, queries);
|
||||
log.info("query size after agent filter: {}", queryContext.getCandidateQueries().size());
|
||||
}
|
||||
|
||||
private void filterQueries(QueryContext queryContext, List<SemanticQuery> queries) {
|
||||
Agent agent = queryContext.getAgent();
|
||||
if (agent == null) {
|
||||
return;
|
||||
}
|
||||
List<RuleParserTool> queryTools = getRuleTools(agent);
|
||||
if (CollectionUtils.isEmpty(queryTools)) {
|
||||
queryContext.setCandidateQueries(Lists.newArrayList());
|
||||
return;
|
||||
}
|
||||
log.info("agent name :{}, queries resolved: {}", agent.getName(),
|
||||
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
|
||||
queries.removeIf(query -> {
|
||||
for (RuleParserTool tool : queryTools) {
|
||||
if (CollectionUtils.isNotEmpty(tool.getQueryModes())
|
||||
&& !tool.getQueryModes().contains(query.getQueryMode())) {
|
||||
return true;
|
||||
}
|
||||
if (CollectionUtils.isEmpty(tool.getDataSetIds())) {
|
||||
return true;
|
||||
}
|
||||
if (tool.isContainsAllModel()) {
|
||||
return false;
|
||||
}
|
||||
return !tool.getDataSetIds().contains(query.getParseInfo().getDataSetId());
|
||||
}
|
||||
return true;
|
||||
});
|
||||
queryContext.setCandidateQueries(queries);
|
||||
log.info("agent name :{}, rule queries witch can be supported by agent :{}", agent.getName(),
|
||||
queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList()));
|
||||
}
|
||||
|
||||
private static List<RuleParserTool> getRuleTools(Agent agent) {
|
||||
if (agent == null) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
List<String> tools = agent.getTools(AgentToolType.NL2SQL_RULE);
|
||||
if (CollectionUtils.isEmpty(tools)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
return tools.stream().map(tool -> JSONObject.parseObject(tool, RuleParserTool.class))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.query.llm;
|
||||
|
||||
import com.tencent.supersonic.chat.core.query.BaseSemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public abstract class LLMSemanticQuery extends BaseSemanticQuery {
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.query.llm.analytics;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class LLMAnswerReq {
|
||||
|
||||
private String queryText;
|
||||
|
||||
private String pluginOutput;
|
||||
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.query.llm.analytics;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class LLMAnswerResp {
|
||||
private String assistantMessage;
|
||||
|
||||
}
|
||||
@@ -1,144 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.query.llm.analytics;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.core.query.llm.LLMSemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.core.utils.QueryReqBuilder;
|
||||
import com.tencent.supersonic.common.pojo.Aggregator;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.calcite.sql.parser.SqlParseException;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class MetricAnalyzeQuery extends LLMSemanticQuery {
|
||||
|
||||
|
||||
public static final String QUERY_MODE = "METRIC_INTERPRET";
|
||||
|
||||
public MetricAnalyzeQuery() {
|
||||
QueryManager.register(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getQueryMode() {
|
||||
return QUERY_MODE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult execute(User user) throws SqlParseException {
|
||||
QueryStructReq queryStructReq = convertQueryStruct();
|
||||
SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
|
||||
|
||||
SemanticQueryResp semanticQueryResp = semanticInterpreter.queryByStruct(queryStructReq, user);
|
||||
String text = generateTableText(semanticQueryResp);
|
||||
Map<String, Object> properties = parseInfo.getProperties();
|
||||
Map<String, String> replacedMap = new HashMap<>();
|
||||
String textReplaced = replaceText((String) properties.get("queryText"),
|
||||
parseInfo.getElementMatches(), replacedMap);
|
||||
String answer = replaceAnswer(fetchInterpret(textReplaced, text), replacedMap);
|
||||
QueryResult queryResult = new QueryResult();
|
||||
List<QueryColumn> queryColumns = Lists.newArrayList(new QueryColumn("结果", "string", "answer"));
|
||||
Map<String, Object> result = new HashMap<>();
|
||||
result.put("answer", answer);
|
||||
List<Map<String, Object>> resultList = Lists.newArrayList();
|
||||
resultList.add(result);
|
||||
queryResult.setQueryResults(resultList);
|
||||
queryResult.setQueryColumns(queryColumns);
|
||||
queryResult.setQueryMode(getQueryMode());
|
||||
queryResult.setQueryState(QueryState.SUCCESS);
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initS2Sql(SemanticSchema semanticSchema, User user) {
|
||||
initS2SqlByStruct(semanticSchema);
|
||||
}
|
||||
|
||||
protected QueryStructReq convertQueryStruct() {
|
||||
QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(parseInfo);
|
||||
fillAggregator(queryStructReq, parseInfo.getMetrics());
|
||||
queryStructReq.setQueryType(QueryType.TAG);
|
||||
return queryStructReq;
|
||||
}
|
||||
|
||||
private String replaceText(String text, List<SchemaElementMatch> schemaElementMatches,
|
||||
Map<String, String> replacedMap) {
|
||||
if (CollectionUtils.isEmpty(schemaElementMatches)) {
|
||||
return text;
|
||||
}
|
||||
List<SchemaElementMatch> valueSchemaElementMatches = schemaElementMatches.stream()
|
||||
.filter(schemaElementMatch ->
|
||||
SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType())
|
||||
|| SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
|
||||
.collect(Collectors.toList());
|
||||
for (SchemaElementMatch schemaElementMatch : valueSchemaElementMatches) {
|
||||
String detectWord = schemaElementMatch.getDetectWord();
|
||||
if (StringUtils.isBlank(detectWord)) {
|
||||
continue;
|
||||
}
|
||||
text = text.replace(detectWord, "xxx");
|
||||
replacedMap.put("xxx", detectWord);
|
||||
}
|
||||
return text;
|
||||
}
|
||||
|
||||
private void fillAggregator(QueryStructReq queryStructReq, Set<SchemaElement> schemaElements) {
|
||||
queryStructReq.getAggregators().clear();
|
||||
for (SchemaElement schemaElement : schemaElements) {
|
||||
Aggregator aggregator = new Aggregator();
|
||||
aggregator.setColumn(schemaElement.getBizName());
|
||||
aggregator.setFunc(AggOperatorEnum.SUM);
|
||||
aggregator.setNameCh(schemaElement.getName());
|
||||
queryStructReq.getAggregators().add(aggregator);
|
||||
}
|
||||
}
|
||||
|
||||
private String replaceAnswer(String text, Map<String, String> replacedMap) {
|
||||
for (String key : replacedMap.keySet()) {
|
||||
text = text.replaceAll(key, replacedMap.get(key));
|
||||
}
|
||||
return text;
|
||||
}
|
||||
|
||||
public static String generateTableText(SemanticQueryResp result) {
|
||||
StringBuilder tableBuilder = new StringBuilder();
|
||||
for (QueryColumn column : result.getColumns()) {
|
||||
tableBuilder.append(column.getName()).append("\t");
|
||||
}
|
||||
tableBuilder.append("\n");
|
||||
for (Map<String, Object> row : result.getResultList()) {
|
||||
for (QueryColumn column : result.getColumns()) {
|
||||
tableBuilder.append(row.get(column.getNameEn())).append("\t");
|
||||
}
|
||||
tableBuilder.append("\n");
|
||||
}
|
||||
return tableBuilder.toString();
|
||||
}
|
||||
|
||||
public String fetchInterpret(String queryText, String dataText) {
|
||||
return "";
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,72 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.query.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SqlInfo;
|
||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.core.utils.QueryReqBuilder;
|
||||
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.core.query.llm.LLMSemanticQuery;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class LLMSqlQuery extends LLMSemanticQuery {
|
||||
|
||||
public static final String QUERY_MODE = "LLM_S2SQL";
|
||||
protected SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
|
||||
|
||||
public LLMSqlQuery() {
|
||||
QueryManager.register(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getQueryMode() {
|
||||
return QUERY_MODE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult execute(User user) {
|
||||
|
||||
long startTime = System.currentTimeMillis();
|
||||
String querySql = parseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
QuerySqlReq querySQLReq = QueryReqBuilder.buildS2SQLReq(querySql, parseInfo.getDataSetId());
|
||||
SemanticQueryResp queryResp = semanticInterpreter.queryByS2SQL(querySQLReq, user);
|
||||
|
||||
log.info("queryByS2SQL cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql);
|
||||
|
||||
QueryResult queryResult = new QueryResult();
|
||||
if (Objects.nonNull(queryResp)) {
|
||||
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
|
||||
}
|
||||
String resultQql = queryResp == null ? null : queryResp.getSql();
|
||||
List<Map<String, Object>> resultList = queryResp == null ? new ArrayList<>() : queryResp.getResultList();
|
||||
List<QueryColumn> columns = queryResp == null ? new ArrayList<>() : queryResp.getColumns();
|
||||
queryResult.setQuerySql(resultQql);
|
||||
queryResult.setQueryResults(resultList);
|
||||
queryResult.setQueryColumns(columns);
|
||||
queryResult.setQueryMode(QUERY_MODE);
|
||||
queryResult.setQueryState(QueryState.SUCCESS);
|
||||
|
||||
parseInfo.setProperties(null);
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initS2Sql(SemanticSchema semanticSchema, User user) {
|
||||
SqlInfo sqlInfo = parseInfo.getSqlInfo();
|
||||
sqlInfo.setCorrectS2SQL(sqlInfo.getS2SQL());
|
||||
}
|
||||
}
|
||||
@@ -1,57 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.query.plugin.webpage;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.core.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.query.plugin.WebBase;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class WebPageQuery extends PluginSemanticQuery {
|
||||
|
||||
public static String QUERY_MODE = "WEB_PAGE";
|
||||
|
||||
public WebPageQuery() {
|
||||
QueryManager.register(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getQueryMode() {
|
||||
return QUERY_MODE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult execute(User user) {
|
||||
QueryResult queryResult = new QueryResult();
|
||||
queryResult.setQueryMode(QUERY_MODE);
|
||||
Map<String, Object> properties = parseInfo.getProperties();
|
||||
PluginParseResult pluginParseResult = JsonUtil.toObject(JsonUtil.toString(properties.get(Constants.CONTEXT)),
|
||||
PluginParseResult.class);
|
||||
WebPageResp webPageResponse = buildResponse(pluginParseResult);
|
||||
queryResult.setResponse(webPageResponse);
|
||||
queryResult.setQueryState(QueryState.SUCCESS);
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
protected WebPageResp buildResponse(PluginParseResult pluginParseResult) {
|
||||
Plugin plugin = pluginParseResult.getPlugin();
|
||||
WebPageResp webPageResponse = new WebPageResp();
|
||||
webPageResponse.setName(plugin.getName());
|
||||
webPageResponse.setPluginId(plugin.getId());
|
||||
webPageResponse.setPluginType(plugin.getType());
|
||||
WebBase webPage = JsonUtil.toObject(plugin.getConfig(), WebBase.class);
|
||||
WebBase webBase = fillWebBaseResult(webPage, pluginParseResult);
|
||||
webPageResponse.setWebPage(webBase);
|
||||
return webPageResponse;
|
||||
}
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.query.plugin.webservice;
|
||||
|
||||
import com.tencent.supersonic.chat.core.query.plugin.WebBase;
|
||||
import lombok.Data;
|
||||
|
||||
|
||||
@Data
|
||||
public class WebServiceResp {
|
||||
|
||||
private WebBase webBase;
|
||||
|
||||
private Object result;
|
||||
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.query.rule.metric;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.OptionType.OPTIONAL;
|
||||
import static com.tencent.supersonic.chat.core.query.rule.QueryMatchOption.RequireNumberType.AT_MOST;
|
||||
@Component
|
||||
public class MetricModelQuery extends MetricSemanticQuery {
|
||||
|
||||
public static final String QUERY_MODE = "METRIC_MODEL";
|
||||
|
||||
public MetricModelQuery() {
|
||||
super();
|
||||
queryMatcher.addOption(SchemaElementType.DATASET, OPTIONAL, AT_MOST, 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getQueryMode() {
|
||||
return QUERY_MODE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult execute(User user) {
|
||||
QueryResult queryResult = super.execute(user);
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,68 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.query.semantic;
|
||||
|
||||
import com.google.common.cache.Cache;
|
||||
import com.google.common.cache.CacheBuilder;
|
||||
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
@Slf4j
|
||||
public abstract class BaseSemanticInterpreter implements SemanticInterpreter {
|
||||
|
||||
protected final Cache<String, List<DataSetSchemaResp>> dataSetSchemaCache =
|
||||
CacheBuilder.newBuilder().expireAfterWrite(10, TimeUnit.SECONDS).build();
|
||||
|
||||
@SneakyThrows
|
||||
public List<DataSetSchemaResp> fetchDataSetSchema(List<Long> ids, Boolean cacheEnable) {
|
||||
if (cacheEnable) {
|
||||
return dataSetSchemaCache.get(String.valueOf(ids), () -> {
|
||||
List<DataSetSchemaResp> data = doFetchDataSetSchema(ids);
|
||||
dataSetSchemaCache.put(String.valueOf(ids), data);
|
||||
return data;
|
||||
});
|
||||
}
|
||||
return doFetchDataSetSchema(ids);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataSetSchema getDataSetSchema(Long dataSetId, Boolean cacheEnable) {
|
||||
List<Long> ids = new ArrayList<>();
|
||||
ids.add(dataSetId);
|
||||
List<DataSetSchemaResp> dataSetSchemaResps = fetchDataSetSchema(ids, cacheEnable);
|
||||
if (!CollectionUtils.isEmpty(dataSetSchemaResps)) {
|
||||
Optional<DataSetSchemaResp> dataSetSchemaResp = dataSetSchemaResps.stream()
|
||||
.filter(d -> d.getId().equals(dataSetId)).findFirst();
|
||||
if (dataSetSchemaResp.isPresent()) {
|
||||
DataSetSchemaResp dataSetSchema = dataSetSchemaResp.get();
|
||||
return DataSetSchemaBuilder.build(dataSetSchema);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataSetSchema> getDataSetSchema() {
|
||||
return getDataSetSchema(new ArrayList<>());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataSetSchema> getDataSetSchema(List<Long> ids) {
|
||||
List<DataSetSchema> domainSchemaList = new ArrayList<>();
|
||||
|
||||
for (DataSetSchemaResp resp : fetchDataSetSchema(ids, true)) {
|
||||
domainSchemaList.add(DataSetSchemaBuilder.build(resp));
|
||||
}
|
||||
|
||||
return domainSchemaList;
|
||||
}
|
||||
|
||||
protected abstract List<DataSetSchemaResp> doFetchDataSetSchema(List<Long> ids);
|
||||
|
||||
}
|
||||
@@ -1,106 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.query.semantic;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DataSetFilterReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.PageDimensionReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.PageMetricReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.server.service.DimensionService;
|
||||
import com.tencent.supersonic.headless.server.service.MetricService;
|
||||
import com.tencent.supersonic.headless.server.service.QueryService;
|
||||
import com.tencent.supersonic.headless.server.service.SchemaService;
|
||||
import java.util.List;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@Slf4j
|
||||
public class LocalSemanticInterpreter extends BaseSemanticInterpreter {
|
||||
|
||||
private SchemaService schemaService;
|
||||
private DimensionService dimensionService;
|
||||
private MetricService metricService;
|
||||
private QueryService queryService;
|
||||
|
||||
@SneakyThrows
|
||||
@Override
|
||||
public SemanticQueryResp queryByStruct(QueryStructReq queryStructReq, User user) {
|
||||
queryService = ContextUtils.getBean(QueryService.class);
|
||||
if (queryStructReq.isConvertToSql()) {
|
||||
return queryService.queryByReq(queryStructReq.convert(), user);
|
||||
}
|
||||
return queryService.queryByReq(queryStructReq, user);
|
||||
}
|
||||
|
||||
@Override
|
||||
@SneakyThrows
|
||||
public SemanticQueryResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user) {
|
||||
queryService = ContextUtils.getBean(QueryService.class);
|
||||
return queryService.queryByReq(queryMultiStructReq, user);
|
||||
}
|
||||
|
||||
@Override
|
||||
@SneakyThrows
|
||||
public SemanticQueryResp queryByS2SQL(QuerySqlReq querySqlReq, User user) {
|
||||
queryService = ContextUtils.getBean(QueryService.class);
|
||||
SemanticQueryResp object = queryService.queryByReq(querySqlReq, user);
|
||||
return JsonUtil.toObject(JsonUtil.toString(object), SemanticQueryResp.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataSetSchemaResp> doFetchDataSetSchema(List<Long> ids) {
|
||||
DataSetFilterReq filter = new DataSetFilterReq();
|
||||
filter.setDataSetIds(ids);
|
||||
schemaService = ContextUtils.getBean(SchemaService.class);
|
||||
return schemaService.fetchDataSetSchema(filter);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DomainResp> getDomainList(User user) {
|
||||
schemaService = ContextUtils.getBean(SchemaService.class);
|
||||
return schemaService.getDomainList(user);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataSetResp> getDataSetList(Long domainId) {
|
||||
schemaService = ContextUtils.getBean(SchemaService.class);
|
||||
return schemaService.getDataSetList(domainId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception {
|
||||
queryService = ContextUtils.getBean(QueryService.class);
|
||||
return queryService.explain(explainSqlReq, user);
|
||||
}
|
||||
|
||||
@Override
|
||||
public PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionCmd) {
|
||||
dimensionService = ContextUtils.getBean(DimensionService.class);
|
||||
return dimensionService.queryDimension(pageDimensionCmd);
|
||||
}
|
||||
|
||||
@Override
|
||||
public PageInfo<MetricResp> getMetricPage(PageMetricReq pageMetricReq, User user) {
|
||||
metricService = ContextUtils.getBean(MetricService.class);
|
||||
return metricService.queryMetric(pageMetricReq, user);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ItemResp> getDomainDataSetTree() {
|
||||
return schemaService.getDomainDataSetTree();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,253 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.query.semantic;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.Constants.LIST_LOWER;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.PAGESIZE_LOWER;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.TOTAL_LOWER;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.google.gson.Gson;
|
||||
import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.core.config.DefaultSemanticConfig;
|
||||
import com.tencent.supersonic.common.pojo.ResultData;
|
||||
import com.tencent.supersonic.common.pojo.enums.ReturnCode;
|
||||
import com.tencent.supersonic.common.pojo.exception.CommonException;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.S2ThreadContext;
|
||||
import com.tencent.supersonic.common.util.ThreadContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.PageDimensionReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.PageMetricReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import java.net.URI;
|
||||
import java.net.URL;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
@Slf4j
|
||||
public class RemoteSemanticInterpreter extends BaseSemanticInterpreter {
|
||||
|
||||
private S2ThreadContext s2ThreadContext;
|
||||
|
||||
private AuthenticationConfig authenticationConfig;
|
||||
|
||||
private ParameterizedTypeReference<ResultData<SemanticQueryResp>> structTypeRef =
|
||||
new ParameterizedTypeReference<ResultData<SemanticQueryResp>>() {
|
||||
};
|
||||
|
||||
private ParameterizedTypeReference<ResultData<ExplainResp>> explainTypeRef =
|
||||
new ParameterizedTypeReference<ResultData<ExplainResp>>() {
|
||||
};
|
||||
|
||||
@Override
|
||||
public SemanticQueryResp queryByStruct(QueryStructReq queryStructReq, User user) {
|
||||
return queryByS2SQL(queryStructReq.convert(), user);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SemanticQueryResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user) {
|
||||
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
|
||||
return searchByRestTemplate(
|
||||
defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getSearchByMultiStructPath(),
|
||||
new Gson().toJson(queryMultiStructReq));
|
||||
}
|
||||
|
||||
@Override
|
||||
public SemanticQueryResp queryByS2SQL(QuerySqlReq querySqlReq, User user) {
|
||||
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
|
||||
return searchByRestTemplate(defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getSearchBySqlPath(),
|
||||
new Gson().toJson(querySqlReq));
|
||||
}
|
||||
|
||||
public SemanticQueryResp searchByRestTemplate(String url, String jsonReq) {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
fillToken(headers);
|
||||
URI requestUrl = UriComponentsBuilder.fromHttpUrl(url).build().encode().toUri();
|
||||
HttpEntity<String> entity = new HttpEntity<>(jsonReq, headers);
|
||||
log.info("url:{},searchByRestTemplate:{}", url, entity.getBody());
|
||||
ResultData<SemanticQueryResp> responseBody;
|
||||
try {
|
||||
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
|
||||
|
||||
ResponseEntity<ResultData<SemanticQueryResp>> responseEntity = restTemplate.exchange(
|
||||
requestUrl, HttpMethod.POST, entity, structTypeRef);
|
||||
responseBody = responseEntity.getBody();
|
||||
log.info("ApiResponse<QueryResultWithColumns> responseBody:{}", responseBody);
|
||||
SemanticQueryResp schemaResp = new SemanticQueryResp();
|
||||
if (ReturnCode.SUCCESS.getCode() == responseBody.getCode()) {
|
||||
SemanticQueryResp data = responseBody.getData();
|
||||
schemaResp.setColumns(data.getColumns());
|
||||
schemaResp.setResultList(data.getResultList());
|
||||
schemaResp.setSql(data.getSql());
|
||||
schemaResp.setQueryAuthorization(data.getQueryAuthorization());
|
||||
return schemaResp;
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("search headless interface error,url:" + url, e);
|
||||
}
|
||||
throw new CommonException(responseBody.getCode(), responseBody.getMsg());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DomainResp> getDomainList(User user) {
|
||||
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
|
||||
Object domainDescListObject = fetchHttpResult(
|
||||
defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getFetchDomainListPath(),
|
||||
null, HttpMethod.GET);
|
||||
return JsonUtil.toList(JsonUtil.toString(domainDescListObject), DomainResp.class);
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> ExplainResp explain(ExplainSqlReq<T> explainResp, User user) throws Exception {
|
||||
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
|
||||
String semanticUrl = defaultSemanticConfig.getSemanticUrl();
|
||||
String explainPath = defaultSemanticConfig.getExplainPath();
|
||||
URL url = new URL(new URL(semanticUrl), explainPath);
|
||||
return explain(url.toString(), JsonUtil.toString(explainResp));
|
||||
}
|
||||
|
||||
public ExplainResp explain(String url, String jsonReq) {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
fillToken(headers);
|
||||
URI requestUrl = UriComponentsBuilder.fromHttpUrl(url).build().encode().toUri();
|
||||
HttpEntity<String> entity = new HttpEntity<>(jsonReq, headers);
|
||||
log.info("url:{},explain:{}", url, entity.getBody());
|
||||
ResultData<ExplainResp> responseBody;
|
||||
try {
|
||||
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
|
||||
|
||||
ResponseEntity<ResultData<ExplainResp>> responseEntity = restTemplate.exchange(
|
||||
requestUrl, HttpMethod.POST, entity, explainTypeRef);
|
||||
log.info("ApiResponse<ExplainResp> responseBody:{}", responseEntity);
|
||||
responseBody = responseEntity.getBody();
|
||||
if (Objects.nonNull(responseBody.getData())) {
|
||||
return responseBody.getData();
|
||||
}
|
||||
return null;
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("explain interface error,url:" + url, e);
|
||||
}
|
||||
}
|
||||
|
||||
public Object fetchHttpResult(String url, String bodyJson, HttpMethod httpMethod) {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
fillToken(headers);
|
||||
URI requestUrl = UriComponentsBuilder.fromHttpUrl(url).build().encode().toUri();
|
||||
ParameterizedTypeReference<ResultData<Object>> responseTypeRef =
|
||||
new ParameterizedTypeReference<ResultData<Object>>() {
|
||||
};
|
||||
HttpEntity<String> entity = new HttpEntity<>(JsonUtil.toString(bodyJson), headers);
|
||||
try {
|
||||
RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class);
|
||||
ResponseEntity<ResultData<Object>> responseEntity = restTemplate.exchange(requestUrl,
|
||||
httpMethod, entity, responseTypeRef);
|
||||
ResultData<Object> responseBody = responseEntity.getBody();
|
||||
log.debug("ApiResponse<fetchModelSchema> responseBody:{}", responseBody);
|
||||
if (ReturnCode.SUCCESS.getCode() == responseBody.getCode()) {
|
||||
Object data = responseBody.getData();
|
||||
return data;
|
||||
}
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("fetchModelSchema interface error", e);
|
||||
}
|
||||
throw new RuntimeException("fetchModelSchema interface error");
|
||||
}
|
||||
|
||||
public void fillToken(HttpHeaders headers) {
|
||||
s2ThreadContext = ContextUtils.getBean(S2ThreadContext.class);
|
||||
authenticationConfig = ContextUtils.getBean(AuthenticationConfig.class);
|
||||
ThreadContext threadContext = s2ThreadContext.get();
|
||||
if (Objects.nonNull(threadContext) && Strings.isNotEmpty(threadContext.getToken())) {
|
||||
if (Objects.nonNull(authenticationConfig) && Strings.isNotEmpty(
|
||||
authenticationConfig.getTokenHttpHeaderKey())) {
|
||||
headers.set(authenticationConfig.getTokenHttpHeaderKey(), threadContext.getToken());
|
||||
}
|
||||
} else {
|
||||
log.debug("threadContext is null:{}", Objects.isNull(threadContext));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public PageInfo<MetricResp> getMetricPage(PageMetricReq pageMetricCmd, User user) {
|
||||
String body = JsonUtil.toString(pageMetricCmd);
|
||||
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
|
||||
log.info("url:{}", defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getFetchMetricPagePath());
|
||||
Object dimensionListObject = fetchHttpResult(
|
||||
defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getFetchMetricPagePath(),
|
||||
body, HttpMethod.POST);
|
||||
LinkedHashMap map = (LinkedHashMap) dimensionListObject;
|
||||
PageInfo<Object> metricDescObjectPageInfo = generatePageInfo(map);
|
||||
PageInfo<MetricResp> metricDescPageInfo = new PageInfo<>();
|
||||
BeanUtils.copyProperties(metricDescObjectPageInfo, metricDescPageInfo);
|
||||
metricDescPageInfo.setList(metricDescPageInfo.getList());
|
||||
return metricDescPageInfo;
|
||||
}
|
||||
|
||||
@Override
|
||||
public PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionCmd) {
|
||||
String body = JsonUtil.toString(pageDimensionCmd);
|
||||
DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class);
|
||||
Object dimensionListObject = fetchHttpResult(
|
||||
defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getFetchDimensionPagePath(),
|
||||
body, HttpMethod.POST);
|
||||
LinkedHashMap map = (LinkedHashMap) dimensionListObject;
|
||||
PageInfo<Object> dimensionDescObjectPageInfo = generatePageInfo(map);
|
||||
PageInfo<DimensionResp> dimensionDescPageInfo = new PageInfo<>();
|
||||
BeanUtils.copyProperties(dimensionDescObjectPageInfo, dimensionDescPageInfo);
|
||||
dimensionDescPageInfo.setList(dimensionDescPageInfo.getList());
|
||||
return dimensionDescPageInfo;
|
||||
}
|
||||
|
||||
private PageInfo<Object> generatePageInfo(LinkedHashMap map) {
|
||||
PageInfo<Object> pageInfo = new PageInfo<>();
|
||||
pageInfo.setList((List<Object>) map.get(LIST_LOWER));
|
||||
Integer total = (Integer) map.get(TOTAL_LOWER);
|
||||
pageInfo.setTotal(total);
|
||||
Integer pageSize = (Integer) map.get(PAGESIZE_LOWER);
|
||||
pageInfo.setPageSize(pageSize);
|
||||
pageInfo.setPages((int) Math.ceil((double) total / pageSize));
|
||||
return pageInfo;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected List<DataSetSchemaResp> doFetchDataSetSchema(List<Long> ids) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ItemResp> getDomainDataSetTree() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataSetResp> getDataSetList(Long domainId) {
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,62 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.query.semantic;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExplainSqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.PageDimensionReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.PageMetricReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMultiStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ExplainResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetSchemaResp;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* A semantic layer provides a simplified and consistent dataSet of data from multiple sources.
|
||||
* It abstracts away the complexity of the underlying data sources and provides a unified dataSet
|
||||
* of the data that is easier to understand and use.
|
||||
* <p>
|
||||
* The interface defines methods for getting metadata as well as querying data in the semantic layer.
|
||||
* Implementations of this interface should provide concrete implementations that interact with the
|
||||
* underlying data sources and return results in a consistent format. Or it can be implemented
|
||||
* as proxy to a remote semantic service.
|
||||
* </p>
|
||||
*/
|
||||
public interface SemanticInterpreter {
|
||||
|
||||
SemanticQueryResp queryByStruct(QueryStructReq queryStructReq, User user);
|
||||
|
||||
SemanticQueryResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user);
|
||||
|
||||
SemanticQueryResp queryByS2SQL(QuerySqlReq querySQLReq, User user);
|
||||
|
||||
List<DataSetSchema> getDataSetSchema();
|
||||
|
||||
List<DataSetSchema> getDataSetSchema(List<Long> ids);
|
||||
|
||||
DataSetSchema getDataSetSchema(Long model, Boolean cacheEnable);
|
||||
|
||||
PageInfo<DimensionResp> getDimensionPage(PageDimensionReq pageDimensionReq);
|
||||
|
||||
PageInfo<MetricResp> getMetricPage(PageMetricReq pageDimensionReq, User user);
|
||||
|
||||
List<DomainResp> getDomainList(User user);
|
||||
|
||||
<T> ExplainResp explain(ExplainSqlReq<T> explainSqlReq, User user) throws Exception;
|
||||
|
||||
List<DataSetSchemaResp> fetchDataSetSchema(List<Long> ids, Boolean cacheEnable);
|
||||
|
||||
List<DataSetResp> getDataSetList(Long domainId);
|
||||
|
||||
List<ItemResp> getDomainDataSetTree();
|
||||
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.utils;
|
||||
|
||||
import com.tencent.supersonic.chat.core.parser.JavaLLMProxy;
|
||||
import com.tencent.supersonic.chat.core.parser.LLMProxy;
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.DataSetResolver;
|
||||
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.core.io.support.SpringFactoriesLoader;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
@Slf4j
|
||||
public class ComponentFactory {
|
||||
|
||||
private static SemanticInterpreter semanticInterpreter;
|
||||
private static LLMProxy llmProxy;
|
||||
private static DataSetResolver modelResolver;
|
||||
|
||||
public static SemanticInterpreter getSemanticLayer() {
|
||||
if (Objects.isNull(semanticInterpreter)) {
|
||||
semanticInterpreter = init(SemanticInterpreter.class);
|
||||
}
|
||||
return semanticInterpreter;
|
||||
}
|
||||
|
||||
public static LLMProxy getLLMProxy() {
|
||||
//1.Preferentially retrieve from environment variables
|
||||
String llmProxyEnv = System.getenv("llmProxy");
|
||||
if (StringUtils.isNotBlank(llmProxyEnv)) {
|
||||
Map<String, LLMProxy> implementations = ContextUtils.getBeansOfType(LLMProxy.class);
|
||||
llmProxy = implementations.entrySet().stream()
|
||||
.filter(entry -> entry.getKey().equalsIgnoreCase(llmProxyEnv))
|
||||
.map(Map.Entry::getValue)
|
||||
.findFirst()
|
||||
.orElse(null);
|
||||
}
|
||||
//2.default JavaLLMProxy
|
||||
if (Objects.isNull(llmProxy)) {
|
||||
llmProxy = ContextUtils.getBean(JavaLLMProxy.class);
|
||||
}
|
||||
log.info("llmProxy:{}", llmProxy);
|
||||
return llmProxy;
|
||||
}
|
||||
|
||||
public static DataSetResolver getModelResolver() {
|
||||
if (Objects.isNull(modelResolver)) {
|
||||
modelResolver = init(DataSetResolver.class);
|
||||
}
|
||||
return modelResolver;
|
||||
}
|
||||
|
||||
private static <T> T init(Class<T> factoryType) {
|
||||
return SpringFactoriesLoader.loadFactories(factoryType,
|
||||
Thread.currentThread().getContextClassLoader()).get(0);
|
||||
}
|
||||
}
|
||||
@@ -1,158 +0,0 @@
|
||||
package com.tencent.supersonic.chat.core.utils;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.SimilarQueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.logging.log4j.util.Strings;
|
||||
import org.springframework.core.ParameterizedTypeReference;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
import java.net.URI;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class SimilarQueryManager {
|
||||
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||
|
||||
|
||||
public SimilarQueryManager(EmbeddingConfig embeddingConfig) {
|
||||
this.embeddingConfig = embeddingConfig;
|
||||
}
|
||||
|
||||
public void saveSimilarQuery(SimilarQueryReq similarQueryReq) {
|
||||
if (StringUtils.isBlank(embeddingConfig.getUrl())) {
|
||||
return;
|
||||
}
|
||||
String queryText = similarQueryReq.getQueryText();
|
||||
try {
|
||||
String uniqueId = generateUniqueId(similarQueryReq.getQueryId(), similarQueryReq.getParseId());
|
||||
EmbeddingQuery embeddingQuery = new EmbeddingQuery();
|
||||
embeddingQuery.setQueryId(uniqueId);
|
||||
embeddingQuery.setQuery(queryText);
|
||||
|
||||
Map<String, Object> metaData = new HashMap<>();
|
||||
metaData.put("modelId", similarQueryReq.getDataSetId());
|
||||
metaData.put("agentId", similarQueryReq.getAgentId());
|
||||
embeddingQuery.setMetadata(metaData);
|
||||
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
|
||||
s2EmbeddingStore.addQuery(solvedQueryCollection, Lists.newArrayList(embeddingQuery));
|
||||
} catch (Exception e) {
|
||||
log.warn("save history question to embedding failed, queryText:{}", queryText, e);
|
||||
}
|
||||
}
|
||||
|
||||
public List<SimilarQueryRecallResp> recallSimilarQuery(String queryText, Integer agentId) {
|
||||
if (StringUtils.isBlank(embeddingConfig.getUrl())) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
List<SimilarQueryRecallResp> similarQueryRecallResps = Lists.newArrayList();
|
||||
try {
|
||||
String solvedQueryCollection = embeddingConfig.getSolvedQueryCollection();
|
||||
int solvedQueryResultNum = embeddingConfig.getSolvedQueryResultNum();
|
||||
|
||||
Map<String, String> filterCondition = new HashMap<>();
|
||||
filterCondition.put("agentId", String.valueOf(agentId));
|
||||
RetrieveQuery retrieveQuery = RetrieveQuery.builder()
|
||||
.queryTextsList(Lists.newArrayList(queryText))
|
||||
.filterCondition(filterCondition)
|
||||
.build();
|
||||
List<RetrieveQueryResult> resultList = s2EmbeddingStore.retrieveQuery(solvedQueryCollection, retrieveQuery,
|
||||
solvedQueryResultNum);
|
||||
|
||||
log.info("[embedding] recognize result body:{}", resultList);
|
||||
Set<String> querySet = new HashSet<>();
|
||||
if (CollectionUtils.isNotEmpty(resultList)) {
|
||||
for (RetrieveQueryResult retrieveQueryResult : resultList) {
|
||||
List<Retrieval> retrievals = retrieveQueryResult.getRetrieval();
|
||||
for (Retrieval retrieval : retrievals) {
|
||||
if (queryText.equalsIgnoreCase(retrieval.getQuery())) {
|
||||
continue;
|
||||
}
|
||||
if (querySet.contains(retrieval.getQuery())) {
|
||||
continue;
|
||||
}
|
||||
String id = retrieval.getId();
|
||||
SimilarQueryRecallResp similarQueryRecallResp = SimilarQueryRecallResp.builder()
|
||||
.queryText(retrieval.getQuery())
|
||||
.queryId(getQueryId(id)).parseId(getParseId(id))
|
||||
.build();
|
||||
similarQueryRecallResps.add(similarQueryRecallResp);
|
||||
querySet.add(retrieval.getQuery());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} catch (Exception e) {
|
||||
log.warn("recall similar solved query failed, queryText:{}", queryText);
|
||||
}
|
||||
return similarQueryRecallResps;
|
||||
}
|
||||
|
||||
private String generateUniqueId(Long queryId, Integer parseId) {
|
||||
String uniqueId = queryId + String.valueOf(parseId);
|
||||
if (parseId < 10) {
|
||||
uniqueId = queryId + String.format("0%s", parseId);
|
||||
}
|
||||
return uniqueId;
|
||||
}
|
||||
|
||||
private Long getQueryId(String uniqueId) {
|
||||
return Long.parseLong(uniqueId) / 100;
|
||||
}
|
||||
|
||||
private Integer getParseId(String uniqueId) {
|
||||
return Integer.parseInt(uniqueId) % 100;
|
||||
}
|
||||
|
||||
private ResponseEntity<String> doRequest(String path, String jsonBody) {
|
||||
if (Strings.isEmpty(embeddingConfig.getUrl())) {
|
||||
return ResponseEntity.of(Optional.empty());
|
||||
}
|
||||
String url = embeddingConfig.getUrl() + path;
|
||||
try {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
headers.setLocation(URI.create(url));
|
||||
URI requestUrl = UriComponentsBuilder
|
||||
.fromHttpUrl(url).build().encode().toUri();
|
||||
HttpEntity<String> entity = new HttpEntity<>(jsonBody, headers);
|
||||
log.info("[embedding] request body :{}, url:{}", jsonBody, url);
|
||||
RestTemplate restTemplate = new RestTemplate();
|
||||
ResponseEntity<String> responseEntity = restTemplate.exchange(requestUrl,
|
||||
HttpMethod.POST, entity, new ParameterizedTypeReference<String>() {
|
||||
});
|
||||
log.info("[embedding] result body:{}", responseEntity);
|
||||
return responseEntity;
|
||||
} catch (Exception e) {
|
||||
log.warn("connect to embedding service failed, url:{}", url);
|
||||
}
|
||||
return ResponseEntity.of(Optional.empty());
|
||||
}
|
||||
|
||||
}
|
||||
@@ -6,7 +6,6 @@
|
||||
<packaging>pom</packaging>
|
||||
<modules>
|
||||
<module>api</module>
|
||||
<module>core</module>
|
||||
<module>server</module>
|
||||
</modules>
|
||||
|
||||
|
||||
@@ -11,19 +11,10 @@
|
||||
|
||||
<artifactId>chat-server</artifactId>
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.springframework</groupId>
|
||||
<artifactId>spring-context</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.tencent.supersonic</groupId>
|
||||
<artifactId>common</artifactId>
|
||||
<version>${project.version}</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.tencent.supersonic</groupId>
|
||||
<artifactId>auth-api</artifactId>
|
||||
@@ -36,13 +27,7 @@
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.tencent.supersonic</groupId>
|
||||
<artifactId>headless-core</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>compile</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.tencent.supersonic</groupId>
|
||||
<artifactId>chat-core</artifactId>
|
||||
<artifactId>headless-server</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>compile</scope>
|
||||
</dependency>
|
||||
@@ -51,12 +36,6 @@
|
||||
<artifactId>junit</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.mockito</groupId>
|
||||
<artifactId>mockito-inline</artifactId>
|
||||
<version>${mockito-inline.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
</project>
|
||||
@@ -1,8 +1,9 @@
|
||||
package com.tencent.supersonic.chat.core.agent;
|
||||
package com.tencent.supersonic.chat.server.agent;
|
||||
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import lombok.Data;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -66,7 +67,11 @@ public class Agent extends RecordInfo {
|
||||
}
|
||||
|
||||
public Set<Long> getDataSetIds() {
|
||||
return getDataSetIds(null);
|
||||
Set<Long> dataSetIds = getDataSetIds(null);
|
||||
if (containsAllModel(dataSetIds)) {
|
||||
return Sets.newHashSet();
|
||||
}
|
||||
return dataSetIds;
|
||||
}
|
||||
|
||||
public Set<Long> getDataSetIds(AgentToolType agentToolType) {
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.core.agent;
|
||||
package com.tencent.supersonic.chat.server.agent;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import lombok.AllArgsConstructor;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.core.agent;
|
||||
package com.tencent.supersonic.chat.server.agent;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.core.agent;
|
||||
package com.tencent.supersonic.chat.server.agent;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.core.agent;
|
||||
package com.tencent.supersonic.chat.server.agent;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.core.agent;
|
||||
package com.tencent.supersonic.chat.server.agent;
|
||||
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.core.agent;
|
||||
package com.tencent.supersonic.chat.server.agent;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.core.agent;
|
||||
package com.tencent.supersonic.chat.server.agent;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
@@ -1,14 +1,14 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.core.config.DefaultMetric;
|
||||
import com.tencent.supersonic.chat.core.config.Dim4Dict;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import com.tencent.supersonic.headless.core.config.DefaultMetric;
|
||||
import com.tencent.supersonic.headless.core.config.Dim4Dict;
|
||||
import lombok.Data;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@Data
|
||||
@ToString
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.repository;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.headless.core.pojo.ChatContext;
|
||||
|
||||
public interface ChatContextRepository {
|
||||
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.repository;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface ChatQueryRepository {
|
||||
@@ -25,8 +25,8 @@ public interface ChatQueryRepository {
|
||||
|
||||
int updateChatQuery(ChatQueryDO chatQueryDO);
|
||||
|
||||
List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryContext queryContext,
|
||||
ParseResp parseResult, List<SemanticParseInfo> candidateParses);
|
||||
List<ChatParseDO> batchSaveParseInfo(ChatParseReq chatParseReq, ParseResp parseResult,
|
||||
List<SemanticParseInfo> candidateParses);
|
||||
|
||||
ChatParseDO getParseInfo(Long questionId, int parseId);
|
||||
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.repository;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface StatisticsRepository {
|
||||
|
||||
void batchSaveStatistics(List<StatisticsDO> list);
|
||||
}
|
||||
@@ -1,16 +1,9 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.repository.impl;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.github.pagehelper.PageHelper;
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDOExample;
|
||||
@@ -21,11 +14,10 @@ import com.tencent.supersonic.chat.server.persistence.mapper.custom.ShowCaseCust
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.common.util.PageUtils;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
@@ -33,6 +25,12 @@ import org.springframework.context.annotation.Primary;
|
||||
import org.springframework.stereotype.Repository;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Repository
|
||||
@Primary
|
||||
@Slf4j
|
||||
@@ -108,21 +106,16 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
queryResult.setQueryId(chatQueryDO.getQuestionId());
|
||||
queryResp.setQueryResult(queryResult);
|
||||
}
|
||||
if (StringUtils.isNotBlank(chatQueryDO.getSimilarQueries())) {
|
||||
List<SimilarQueryRecallResp> similarQueries = JSONObject.parseArray(chatQueryDO.getSimilarQueries(),
|
||||
SimilarQueryRecallResp.class);
|
||||
queryResp.setSimilarQueries(similarQueries);
|
||||
}
|
||||
return queryResp;
|
||||
}
|
||||
|
||||
public Long createChatQuery(ParseResp parseResult, ChatContext chatCtx, QueryContext queryContext) {
|
||||
public Long createChatQuery(ParseResp parseResult, ChatParseReq chatParseReq) {
|
||||
ChatQueryDO chatQueryDO = new ChatQueryDO();
|
||||
chatQueryDO.setChatId(Long.valueOf(chatCtx.getChatId()));
|
||||
chatQueryDO.setChatId(Long.valueOf(chatParseReq.getChatId()));
|
||||
chatQueryDO.setCreateTime(new java.util.Date());
|
||||
chatQueryDO.setUserName(queryContext.getUser().getName());
|
||||
chatQueryDO.setQueryText(queryContext.getQueryText());
|
||||
chatQueryDO.setAgentId(queryContext.getAgentId());
|
||||
chatQueryDO.setUserName(chatParseReq.getUser().getName());
|
||||
chatQueryDO.setQueryText(chatParseReq.getQueryText());
|
||||
chatQueryDO.setAgentId(chatParseReq.getAgentId());
|
||||
chatQueryDO.setQueryResult("");
|
||||
try {
|
||||
chatQueryDOMapper.insert(chatQueryDO);
|
||||
@@ -135,24 +128,24 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatParseDO> batchSaveParseInfo(ChatContext chatCtx, QueryContext queryContext,
|
||||
public List<ChatParseDO> batchSaveParseInfo(ChatParseReq chatParseReq,
|
||||
ParseResp parseResult, List<SemanticParseInfo> candidateParses) {
|
||||
Long queryId = createChatQuery(parseResult, chatCtx, queryContext);
|
||||
Long queryId = createChatQuery(parseResult, chatParseReq);
|
||||
List<ChatParseDO> chatParseDOList = new ArrayList<>();
|
||||
getChatParseDO(chatCtx, queryContext, queryId, candidateParses, chatParseDOList);
|
||||
getChatParseDO(chatParseReq, queryId, candidateParses, chatParseDOList);
|
||||
if (!CollectionUtils.isEmpty(candidateParses)) {
|
||||
chatParseMapper.batchSaveParseInfo(chatParseDOList);
|
||||
}
|
||||
return chatParseDOList;
|
||||
}
|
||||
|
||||
public void getChatParseDO(ChatContext chatCtx, QueryContext queryContext, Long queryId,
|
||||
public void getChatParseDO(ChatParseReq chatParseReq, Long queryId,
|
||||
List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
|
||||
for (int i = 0; i < parses.size(); i++) {
|
||||
ChatParseDO chatParseDO = new ChatParseDO();
|
||||
chatParseDO.setChatId(Long.valueOf(chatCtx.getChatId()));
|
||||
chatParseDO.setChatId(Long.valueOf(chatParseReq.getChatId()));
|
||||
chatParseDO.setQuestionId(queryId);
|
||||
chatParseDO.setQueryText(queryContext.getQueryText());
|
||||
chatParseDO.setQueryText(chatParseReq.getQueryText());
|
||||
chatParseDO.setParseInfo(JsonUtil.toString(parses.get(i)));
|
||||
chatParseDO.setIsCandidate(1);
|
||||
if (i == 0) {
|
||||
@@ -160,7 +153,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
}
|
||||
chatParseDO.setParseId(parses.get(i).getId());
|
||||
chatParseDO.setCreateTime(new java.util.Date());
|
||||
chatParseDO.setUserName(queryContext.getUser().getName());
|
||||
chatParseDO.setUserName(chatParseReq.getUser().getName());
|
||||
chatParseDOList.add(chatParseDO);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.repository.impl;
|
||||
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.StatisticsMapper;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.StatisticsRepository;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Repository
|
||||
@Primary
|
||||
@Slf4j
|
||||
public class StatisticsRepositoryImpl implements StatisticsRepository {
|
||||
|
||||
private final StatisticsMapper statisticsMapper;
|
||||
|
||||
public StatisticsRepositoryImpl(StatisticsMapper statisticsMapper) {
|
||||
this.statisticsMapper = statisticsMapper;
|
||||
}
|
||||
|
||||
public void batchSaveStatistics(List<StatisticsDO> list) {
|
||||
statisticsMapper.batchSaveStatistics(list);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.core.parser.plugin;
|
||||
package com.tencent.supersonic.chat.server.plugin;
|
||||
|
||||
public enum ParseMode {
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
package com.tencent.supersonic.chat.core.plugin;
|
||||
package com.tencent.supersonic.chat.server.plugin;
|
||||
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.ParseMode;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import lombok.Data;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
@@ -1,31 +1,36 @@
|
||||
package com.tencent.supersonic.chat.core.plugin;
|
||||
package com.tencent.supersonic.chat.server.plugin;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.chat.core.agent.Agent;
|
||||
import com.tencent.supersonic.chat.core.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.core.agent.PluginTool;
|
||||
import com.tencent.supersonic.chat.core.plugin.event.PluginAddEvent;
|
||||
import com.tencent.supersonic.chat.core.plugin.event.PluginDelEvent;
|
||||
import com.tencent.supersonic.chat.core.plugin.event.PluginUpdateEvent;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.plugin.ParamOption;
|
||||
import com.tencent.supersonic.chat.core.query.plugin.WebBase;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.agent.PluginTool;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.ParamOption;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginAddEvent;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.PluginService;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.EmbeddingQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.context.event.EventListener;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@@ -43,21 +48,18 @@ import java.util.stream.Collectors;
|
||||
@Component
|
||||
public class PluginManager {
|
||||
|
||||
@Autowired
|
||||
private EmbeddingConfig embeddingConfig;
|
||||
|
||||
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||
|
||||
public PluginManager(EmbeddingConfig embeddingConfig) {
|
||||
this.embeddingConfig = embeddingConfig;
|
||||
}
|
||||
public static List<Plugin> getPluginAgentCanSupport(ChatParseReq chatParseReq) {
|
||||
PluginService pluginService = ContextUtils.getBean(PluginService.class);
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
|
||||
|
||||
public static List<Plugin> getPluginAgentCanSupport(QueryContext queryContext) {
|
||||
List<Plugin> plugins = queryContext.getPluginList();
|
||||
if (Objects.isNull(queryContext.getAgent())) {
|
||||
return plugins;
|
||||
}
|
||||
Agent agent = queryContext.getAgent();
|
||||
if (agent == null) {
|
||||
List<Plugin> plugins = pluginService.getPluginList();
|
||||
if (Objects.isNull(agent)) {
|
||||
return plugins;
|
||||
}
|
||||
List<Long> pluginIds = getPluginTools(agent).stream().map(PluginTool::getPlugins)
|
||||
@@ -1,15 +1,14 @@
|
||||
package com.tencent.supersonic.chat.core.plugin;
|
||||
package com.tencent.supersonic.chat.server.plugin;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.function.Parameters;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.ToString;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.List;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.ToString;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@@ -18,8 +17,6 @@ import lombok.ToString;
|
||||
@NoArgsConstructor
|
||||
public class PluginParseConfig implements Serializable {
|
||||
|
||||
public Parameters parameters;
|
||||
|
||||
public List<String> examples;
|
||||
|
||||
private String name;
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.core.plugin;
|
||||
package com.tencent.supersonic.chat.server.plugin;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.core.plugin;
|
||||
package com.tencent.supersonic.chat.server.plugin;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.core.query.plugin;
|
||||
package com.tencent.supersonic.chat.server.plugin.build;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
package com.tencent.supersonic.chat.core.query.plugin;
|
||||
package com.tencent.supersonic.chat.server.plugin.build;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.core.query.BaseSemanticQuery;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.headless.core.chat.query.BaseSemanticQuery;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -18,11 +19,6 @@ import java.util.Map;
|
||||
@Slf4j
|
||||
public abstract class PluginSemanticQuery extends BaseSemanticQuery {
|
||||
|
||||
@Override
|
||||
public String explain(User user) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initS2Sql(SemanticSchema semanticSchema, User user) {
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package com.tencent.supersonic.chat.core.query.plugin;
|
||||
package com.tencent.supersonic.chat.server.plugin.build;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@@ -0,0 +1,46 @@
|
||||
package com.tencent.supersonic.chat.server.plugin.build.webpage;
|
||||
|
||||
import com.tencent.supersonic.chat.server.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.QueryManager;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.calcite.sql.parser.SqlParseException;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
public class WebPageQuery extends PluginSemanticQuery {
|
||||
|
||||
public static String QUERY_MODE = "WEB_PAGE";
|
||||
|
||||
public WebPageQuery() {
|
||||
QueryManager.register(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getQueryMode() {
|
||||
return QUERY_MODE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SemanticQueryReq buildSemanticQueryReq() throws SqlParseException {
|
||||
return null;
|
||||
}
|
||||
|
||||
protected WebPageResp buildResponse(PluginParseResult pluginParseResult) {
|
||||
Plugin plugin = pluginParseResult.getPlugin();
|
||||
WebPageResp webPageResponse = new WebPageResp();
|
||||
webPageResponse.setName(plugin.getName());
|
||||
webPageResponse.setPluginId(plugin.getId());
|
||||
webPageResponse.setPluginType(plugin.getType());
|
||||
WebBase webPage = JsonUtil.toObject(plugin.getConfig(), WebBase.class);
|
||||
WebBase webBase = fillWebBaseResult(webPage, pluginParseResult);
|
||||
webPageResponse.setWebPage(webBase);
|
||||
return webPageResponse;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -1,9 +1,12 @@
|
||||
package com.tencent.supersonic.chat.core.query.plugin.webpage;
|
||||
package com.tencent.supersonic.chat.server.plugin.build.webpage;
|
||||
|
||||
import com.tencent.supersonic.chat.core.query.plugin.WebBase;
|
||||
|
||||
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@Data
|
||||
public class WebPageResp {
|
||||
|
||||
@@ -1,19 +1,15 @@
|
||||
package com.tencent.supersonic.chat.core.query.plugin.webservice;
|
||||
package com.tencent.supersonic.chat.server.plugin.build.webservice;
|
||||
|
||||
import com.alibaba.fastjson.JSON;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.core.query.plugin.ParamOption;
|
||||
import com.tencent.supersonic.chat.core.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.query.plugin.WebBase;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
import com.tencent.supersonic.chat.server.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.ParamOption;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.core.chat.query.QueryManager;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.calcite.sql.parser.SqlParseException;
|
||||
import org.springframework.http.HttpEntity;
|
||||
@@ -24,6 +20,7 @@ import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.stereotype.Component;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.util.UriComponentsBuilder;
|
||||
|
||||
import java.net.URI;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
@@ -48,26 +45,8 @@ public class WebServiceQuery extends PluginSemanticQuery {
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult execute(User user) throws SqlParseException {
|
||||
QueryResult queryResult = new QueryResult();
|
||||
queryResult.setQueryMode(QUERY_MODE);
|
||||
Map<String, Object> properties = parseInfo.getProperties();
|
||||
PluginParseResult pluginParseResult = JsonUtil.toObject(
|
||||
JsonUtil.toString(properties.get(Constants.CONTEXT)), PluginParseResult.class);
|
||||
WebServiceResp webServiceResponse = buildResponse(pluginParseResult);
|
||||
Object object = webServiceResponse.getResult();
|
||||
// in order to show webServiceQuery result int frontend conveniently,
|
||||
// webServiceResponse result format is consistent with queryByStruct result.
|
||||
log.info("webServiceResponse result:{}", JsonUtil.toString(object));
|
||||
try {
|
||||
Map<String, Object> data = JsonUtil.toMap(JsonUtil.toString(object), String.class, Object.class);
|
||||
queryResult.setQueryResults((List<Map<String, Object>>) data.get("resultList"));
|
||||
queryResult.setQueryColumns((List<QueryColumn>) data.get("columns"));
|
||||
queryResult.setQueryState(QueryState.SUCCESS);
|
||||
} catch (Exception e) {
|
||||
log.info("webServiceResponse result has an exception:{}", e.getMessage());
|
||||
}
|
||||
return queryResult;
|
||||
public SemanticQueryReq buildSemanticQueryReq() throws SqlParseException {
|
||||
return null;
|
||||
}
|
||||
|
||||
protected WebServiceResp buildResponse(PluginParseResult pluginParseResult) {
|
||||
@@ -0,0 +1,15 @@
|
||||
package com.tencent.supersonic.chat.server.plugin.build.webservice;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
|
||||
import lombok.Data;
|
||||
|
||||
|
||||
@Data
|
||||
public class WebServiceResp {
|
||||
|
||||
private WebBase webBase;
|
||||
|
||||
private Object result;
|
||||
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.core.plugin.event;
|
||||
package com.tencent.supersonic.chat.server.plugin.event;
|
||||
|
||||
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.Plugin;
|
||||
import org.springframework.context.ApplicationEvent;
|
||||
|
||||
public class PluginAddEvent extends ApplicationEvent {
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.chat.core.plugin.event;
|
||||
package com.tencent.supersonic.chat.server.plugin.event;
|
||||
|
||||
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||
|
||||
import com.tencent.supersonic.chat.server.plugin.Plugin;
|
||||
import org.springframework.context.ApplicationEvent;
|
||||
|
||||
public class PluginDelEvent extends ApplicationEvent {
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.core.plugin.event;
|
||||
package com.tencent.supersonic.chat.server.plugin.event;
|
||||
|
||||
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.Plugin;
|
||||
import org.springframework.context.ApplicationEvent;
|
||||
|
||||
public class PluginUpdateEvent extends ApplicationEvent {
|
||||
@@ -1,77 +1,68 @@
|
||||
package com.tencent.supersonic.chat.core.parser.plugin;
|
||||
package com.tencent.supersonic.chat.server.plugin.recall;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.chat.core.parser.SemanticParser;
|
||||
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.query.plugin.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.server.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
|
||||
/**
|
||||
* PluginParser defines the basic process and common methods for recalling plugins.
|
||||
*/
|
||||
public abstract class PluginParser implements SemanticParser {
|
||||
public abstract class PluginParser {
|
||||
|
||||
@Override
|
||||
public void parse(QueryContext queryContext, ChatContext chatContext) {
|
||||
for (SemanticQuery semanticQuery : queryContext.getCandidateQueries()) {
|
||||
if (queryContext.getQueryText().length() <= semanticQuery.getParseInfo().getScore()
|
||||
&& (QueryManager.getPluginQueryModes().contains(semanticQuery.getQueryMode()))) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
if (!checkPreCondition(queryContext)) {
|
||||
public void parse(ChatParseReq chatParseReq) {
|
||||
if (!checkPreCondition(chatParseReq)) {
|
||||
return;
|
||||
}
|
||||
PluginRecallResult pluginRecallResult = recallPlugin(queryContext);
|
||||
PluginRecallResult pluginRecallResult = recallPlugin(chatParseReq);
|
||||
if (pluginRecallResult == null) {
|
||||
return;
|
||||
}
|
||||
buildQuery(queryContext, pluginRecallResult);
|
||||
buildQuery(chatParseReq, pluginRecallResult);
|
||||
}
|
||||
|
||||
public abstract boolean checkPreCondition(QueryContext queryContext);
|
||||
public abstract boolean checkPreCondition(ChatParseReq chatParseReq);
|
||||
|
||||
public abstract PluginRecallResult recallPlugin(QueryContext queryContext);
|
||||
public abstract PluginRecallResult recallPlugin(ChatParseReq chatParseReq);
|
||||
|
||||
public void buildQuery(QueryContext queryContext, PluginRecallResult pluginRecallResult) {
|
||||
public void buildQuery(ChatParseReq chatParseReq, PluginRecallResult pluginRecallResult) {
|
||||
Plugin plugin = pluginRecallResult.getPlugin();
|
||||
Set<Long> dataSetIds = pluginRecallResult.getDataSetIds();
|
||||
if (plugin.isContainsAllModel()) {
|
||||
dataSetIds = Sets.newHashSet(-1L);
|
||||
}
|
||||
for (Long dataSetId : dataSetIds) {
|
||||
PluginSemanticQuery pluginQuery = QueryManager.createPluginQuery(plugin.getType());
|
||||
//todo
|
||||
PluginSemanticQuery pluginQuery = null;
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin,
|
||||
queryContext, pluginRecallResult.getDistance());
|
||||
null, pluginRecallResult.getDistance());
|
||||
semanticParseInfo.setQueryMode(pluginQuery.getQueryMode());
|
||||
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
||||
pluginQuery.setParseInfo(semanticParseInfo);
|
||||
queryContext.getCandidateQueries().add(pluginQuery);
|
||||
//chatParseReq.getCandidateQueries().add(pluginQuery);
|
||||
}
|
||||
}
|
||||
|
||||
protected List<Plugin> getPluginList(QueryContext queryContext) {
|
||||
return PluginManager.getPluginAgentCanSupport(queryContext);
|
||||
protected List<Plugin> getPluginList(ChatParseReq chatParseReq) {
|
||||
return PluginManager.getPluginAgentCanSupport(chatParseReq);
|
||||
}
|
||||
|
||||
protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, Plugin plugin,
|
||||
@@ -1,18 +1,18 @@
|
||||
package com.tencent.supersonic.chat.core.parser.plugin.embedding;
|
||||
package com.tencent.supersonic.chat.server.plugin.recall.embedding;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.core.parser.PythonLLMProxy;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.ParseMode;
|
||||
import com.tencent.supersonic.chat.core.parser.plugin.PluginParser;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.server.plugin.ParseMode;
|
||||
import com.tencent.supersonic.chat.server.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.recall.PluginParser;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.PythonLLMProxy;
|
||||
import com.tencent.supersonic.headless.core.utils.ComponentFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
@@ -30,31 +30,30 @@ import java.util.stream.Collectors;
|
||||
@Slf4j
|
||||
public class EmbeddingRecallParser extends PluginParser {
|
||||
|
||||
@Override
|
||||
public boolean checkPreCondition(QueryContext queryContext) {
|
||||
public boolean checkPreCondition(ChatParseReq chatParseReq) {
|
||||
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
|
||||
if (StringUtils.isBlank(embeddingConfig.getUrl()) && ComponentFactory.getLLMProxy() instanceof PythonLLMProxy) {
|
||||
return false;
|
||||
}
|
||||
List<Plugin> plugins = getPluginList(queryContext);
|
||||
List<Plugin> plugins = getPluginList(chatParseReq);
|
||||
return !CollectionUtils.isEmpty(plugins);
|
||||
}
|
||||
|
||||
@Override
|
||||
public PluginRecallResult recallPlugin(QueryContext queryContext) {
|
||||
String text = queryContext.getQueryText();
|
||||
public PluginRecallResult recallPlugin(ChatParseReq chatParseReq) {
|
||||
String text = chatParseReq.getQueryText();
|
||||
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
|
||||
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
|
||||
return null;
|
||||
}
|
||||
List<Plugin> plugins = getPluginList(queryContext);
|
||||
List<Plugin> plugins = getPluginList(chatParseReq);
|
||||
Map<Long, Plugin> pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p));
|
||||
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
|
||||
Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
|
||||
if (plugin == null) {
|
||||
continue;
|
||||
}
|
||||
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, queryContext);
|
||||
//todo
|
||||
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, null);
|
||||
log.info("embedding plugin resolve: {}", pair);
|
||||
if (pair.getLeft()) {
|
||||
Set<Long> dataSetList = pair.getRight();
|
||||
@@ -63,7 +62,7 @@ public class EmbeddingRecallParser extends PluginParser {
|
||||
}
|
||||
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
||||
double distance = embeddingRetrieval.getDistance();
|
||||
double score = queryContext.getQueryText().length() * (1 - distance);
|
||||
double score = chatParseReq.getQueryText().length() * (1 - distance);
|
||||
return PluginRecallResult.builder()
|
||||
.plugin(plugin).dataSetIds(dataSetList).score(score).distance(distance).build();
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.core.parser.plugin.embedding;
|
||||
package com.tencent.supersonic.chat.server.plugin.recall.embedding;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
@@ -1,8 +1,7 @@
|
||||
package com.tencent.supersonic.chat.core.parser.plugin.embedding;
|
||||
package com.tencent.supersonic.chat.server.plugin.recall.embedding;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.core.parser.plugin.function;
|
||||
package com.tencent.supersonic.chat.server.plugin.recall.function;
|
||||
|
||||
import lombok.Data;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.core.parser.plugin.function;
|
||||
package com.tencent.supersonic.chat.server.plugin.recall.function;
|
||||
|
||||
|
||||
import lombok.Data;
|
||||
@@ -1,12 +1,18 @@
|
||||
package com.tencent.supersonic.chat.core.parser.plugin.function;
|
||||
package com.tencent.supersonic.chat.server.plugin.recall.function;
|
||||
|
||||
import com.tencent.supersonic.chat.core.parser.sql.llm.InputFormat;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginParseConfig;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.core.chat.parser.llm.InputFormat;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class FunctionPromptGenerator {
|
||||
@@ -41,4 +47,29 @@ public class FunctionPromptGenerator {
|
||||
|
||||
return String.format("工具选择如下:\n\n%s\n\n【任务说明】\n%s", functionList, instruction);
|
||||
}
|
||||
|
||||
public FunctionResp requestFunction(FunctionReq functionReq) {
|
||||
|
||||
FunctionPromptGenerator promptGenerator = ContextUtils.getBean(FunctionPromptGenerator.class);
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ContextUtils.getBean(ChatLanguageModel.class);
|
||||
String functionCallPrompt = promptGenerator.generateFunctionCallPrompt(functionReq.getQueryText(),
|
||||
functionReq.getPluginConfigs());
|
||||
String response = chatLanguageModel.generate(functionCallPrompt);
|
||||
return functionCallParse(response);
|
||||
}
|
||||
|
||||
public static FunctionResp functionCallParse(String llmOutput) {
|
||||
try {
|
||||
ObjectMapper objectMapper = new ObjectMapper();
|
||||
JsonNode jsonNode = objectMapper.readTree(llmOutput);
|
||||
String selectedTool = jsonNode.get("选择工具").asText();
|
||||
FunctionResp resp = new FunctionResp();
|
||||
resp.setToolSelection(selectedTool);
|
||||
return resp;
|
||||
} catch (Exception e) {
|
||||
log.error("", e);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,11 @@
|
||||
package com.tencent.supersonic.chat.core.parser.plugin.function;
|
||||
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginParseConfig;
|
||||
import java.util.List;
|
||||
package com.tencent.supersonic.chat.server.plugin.recall.function;
|
||||
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
public class FunctionReq {
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.chat.core.parser.plugin.function;
|
||||
package com.tencent.supersonic.chat.server.plugin.recall.function;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.chat.core.parser.plugin.function;
|
||||
package com.tencent.supersonic.chat.server.plugin.recall.function;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@@ -5,4 +5,5 @@ package com.tencent.supersonic.chat.server.processor;
|
||||
*/
|
||||
public interface ResultProcessor {
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -3,11 +3,11 @@ package com.tencent.supersonic.chat.server.processor.execute;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.headless.api.pojo.RelatedSchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.service.SemanticService;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.server.service.impl.SemanticService;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
@@ -33,9 +33,9 @@ public class DimensionRecommendProcessor implements ExecuteResultProcessor {
|
||||
|| CollectionUtils.isEmpty(semanticParseInfo.getMetrics())) {
|
||||
return;
|
||||
}
|
||||
SchemaElement element = semanticParseInfo.getMetrics().iterator().next();
|
||||
List<SchemaElement> dimensionRecommended = getDimensions(element.getId(), element.getDataSet());
|
||||
queryResult.setRecommendedDimensions(dimensionRecommended);
|
||||
//SchemaElement element = semanticParseInfo.getMetrics().iterator().next();
|
||||
//List<SchemaElement> dimensionRecommended = getDimensions(element.getId(), element.getDataSet());
|
||||
//queryResult.setRecommendedDimensions(dimensionRecommended);
|
||||
}
|
||||
|
||||
private List<SchemaElement> getDimensions(Long metricId, Long dataSetId) {
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.chat.server.processor.execute;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,25 +1,8 @@
|
||||
package com.tencent.supersonic.chat.server.processor.execute;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DAY;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT_INT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.MONTH;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT_INT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.TIMES_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.TIME_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.WEEK;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.AggregateInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.MetricInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.core.config.AggregatorConfig;
|
||||
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.core.utils.QueryReqBuilder;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
@@ -29,8 +12,16 @@ import com.tencent.supersonic.common.pojo.enums.RatioOverType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.DateUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.core.config.AggregatorConfig;
|
||||
import com.tencent.supersonic.headless.core.utils.QueryReqBuilder;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.text.DecimalFormat;
|
||||
import java.time.DayOfWeek;
|
||||
import java.time.LocalDate;
|
||||
@@ -48,8 +39,16 @@ import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DAY;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.DAY_FORMAT_INT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.MONTH;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.MONTH_FORMAT_INT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.TIMES_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.TIME_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.WEEK;
|
||||
|
||||
/**
|
||||
* Add ratio queries for metric queries.
|
||||
@@ -57,7 +56,7 @@ import org.springframework.util.CollectionUtils;
|
||||
@Slf4j
|
||||
public class MetricRatioProcessor implements ExecuteResultProcessor {
|
||||
|
||||
private SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
|
||||
//private SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
|
||||
|
||||
@Override
|
||||
public void process(QueryResult queryResult, SemanticParseInfo semanticParseInfo, ExecuteQueryReq queryReq) {
|
||||
@@ -68,8 +67,8 @@ public class MetricRatioProcessor implements ExecuteResultProcessor {
|
||||
|| !QueryType.METRIC.equals(semanticParseInfo.getQueryType())) {
|
||||
return;
|
||||
}
|
||||
AggregateInfo aggregateInfo = getAggregateInfo(queryReq.getUser(), semanticParseInfo, queryResult);
|
||||
queryResult.setAggregateInfo(aggregateInfo);
|
||||
//AggregateInfo aggregateInfo = getAggregateInfo(queryReq.getUser(), semanticParseInfo, queryResult);
|
||||
//queryResult.setAggregateInfo(aggregateInfo);
|
||||
}
|
||||
|
||||
public AggregateInfo getAggregateInfo(User user, SemanticParseInfo semanticParseInfo, QueryResult queryResult) {
|
||||
@@ -133,7 +132,7 @@ public class MetricRatioProcessor implements ExecuteResultProcessor {
|
||||
queryStructReq.setDateInfo(getRatioDateConf(aggOperatorEnum, semanticParseInfo, queryResult));
|
||||
queryStructReq.setConvertToSql(false);
|
||||
|
||||
SemanticQueryResp queryResp = semanticInterpreter.queryByStruct(queryStructReq, user);
|
||||
SemanticQueryResp queryResp = null;
|
||||
MetricInfo metricInfo = new MetricInfo();
|
||||
metricInfo.setStatistics(new HashMap<>());
|
||||
if (Objects.isNull(queryResp) || CollectionUtils.isEmpty(queryResp.getResultList())) {
|
||||
|
||||
@@ -1,21 +1,18 @@
|
||||
package com.tencent.supersonic.chat.server.processor.execute;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.util.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQuery;
|
||||
import com.tencent.supersonic.common.util.embedding.RetrieveQueryResult;
|
||||
import com.tencent.supersonic.common.util.embedding.S2EmbeddingStore;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.server.service.MetaEmbeddingService;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.core.knowledge.MetaEmbeddingService;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.Collections;
|
||||
@@ -33,8 +30,6 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
|
||||
|
||||
private static final int METRIC_RECOMMEND_SIZE = 5;
|
||||
|
||||
private S2EmbeddingStore s2EmbeddingStore = ComponentFactory.getS2EmbeddingStore();
|
||||
|
||||
@Override
|
||||
public void process(QueryResult queryResult, SemanticParseInfo semanticParseInfo, ExecuteQueryReq queryReq) {
|
||||
fillSimilarMetric(queryResult.getChatContext());
|
||||
@@ -54,8 +49,7 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
|
||||
.filterCondition(filterCondition).queryEmbeddings(null).build();
|
||||
MetaEmbeddingService metaEmbeddingService = ContextUtils.getBean(MetaEmbeddingService.class);
|
||||
List<RetrieveQueryResult> retrieveQueryResults =
|
||||
metaEmbeddingService.retrieveQuery(Lists.newArrayList(parseInfo.getDataSetId()),
|
||||
retrieveQuery, METRIC_RECOMMEND_SIZE + 1);
|
||||
metaEmbeddingService.retrieveQuery(retrieveQuery, METRIC_RECOMMEND_SIZE + 1, new HashMap<>());
|
||||
if (CollectionUtils.isEmpty(retrieveQueryResults)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.core.query.QueryManager;
|
||||
import com.tencent.supersonic.chat.core.query.SemanticQuery;
|
||||
import com.tencent.supersonic.chat.core.query.llm.analytics.MetricAnalyzeQuery;
|
||||
import com.tencent.supersonic.chat.server.service.SemanticService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* EntityInfoProcessor fills core attributes of an entity so that
|
||||
* users get to know which entity is parsed out.
|
||||
*/
|
||||
public class EntityInfoProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
|
||||
List<SemanticQuery> semanticQueries = queryContext.getCandidateQueries();
|
||||
if (CollectionUtils.isEmpty(semanticQueries)) {
|
||||
return;
|
||||
}
|
||||
List<SemanticParseInfo> selectedParses = semanticQueries.stream().map(SemanticQuery::getParseInfo)
|
||||
.collect(Collectors.toList());
|
||||
selectedParses.forEach(parseInfo -> {
|
||||
String queryMode = parseInfo.getQueryMode();
|
||||
if (QueryManager.containsPluginQuery(queryMode)
|
||||
|| MetricAnalyzeQuery.QUERY_MODE.equalsIgnoreCase(queryMode)) {
|
||||
return;
|
||||
}
|
||||
//1. set entity info
|
||||
DataSetSchema dataSetSchema =
|
||||
queryContext.getSemanticSchema().getDataSetSchemaMap().get(parseInfo.getDataSetId());
|
||||
SemanticService semanticService = ContextUtils.getBean(SemanticService.class);
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, queryContext.getUser());
|
||||
if (QueryManager.isTagQuery(queryMode)
|
||||
|| QueryManager.isMetricQuery(queryMode)) {
|
||||
parseInfo.setEntityInfo(entityInfo);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1,14 +1,10 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
|
||||
|
||||
/**
|
||||
* A ParseResultProcessor wraps things up before returning results to users in parse stage.
|
||||
*/
|
||||
public interface ParseResultProcessor extends ResultProcessor {
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
|
||||
void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext);
|
||||
public interface ParseResultProcessor {
|
||||
|
||||
void process(ParseResp parseResp, ChatParseReq chatParseReq);
|
||||
|
||||
}
|
||||
|
||||
@@ -3,29 +3,30 @@ package com.tencent.supersonic.chat.server.processor.parse;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
||||
import com.tencent.supersonic.chat.core.utils.SimilarQueryManager;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.headless.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.headless.server.processor.ResultProcessor;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.stream.Collectors;
|
||||
import lombok.SneakyThrows;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
/**
|
||||
* MetricRecommendProcessor fills recommended query based on embedding similarity.
|
||||
*/
|
||||
@Slf4j
|
||||
public class QueryRecommendProcessor implements ParseResultProcessor {
|
||||
public class QueryRecommendProcessor implements ResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ParseResp parseResp, QueryContext queryContext, ChatContext chatContext) {
|
||||
@@ -35,8 +36,9 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
|
||||
@SneakyThrows
|
||||
private void doProcess(ParseResp parseResp, QueryContext queryContext) {
|
||||
Long queryId = parseResp.getQueryId();
|
||||
//TODO
|
||||
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(queryContext.getQueryText(),
|
||||
queryContext.getAgentId());
|
||||
null);
|
||||
ChatQueryDO chatQueryDO = getChatQuery(queryId);
|
||||
chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries));
|
||||
updateChatQuery(chatQueryDO);
|
||||
@@ -44,8 +46,8 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
|
||||
|
||||
public List<SimilarQueryRecallResp> getSimilarQueries(String queryText, Integer agentId) {
|
||||
//1. recall solved query by queryText
|
||||
SimilarQueryManager solvedQueryManager = ContextUtils.getBean(SimilarQueryManager.class);
|
||||
List<SimilarQueryRecallResp> similarQueries = solvedQueryManager.recallSimilarQuery(queryText, agentId);
|
||||
//SimilarQueryManager solvedQueryManager = ContextUtils.getBean(SimilarQueryManager.class);
|
||||
List<SimilarQueryRecallResp> similarQueries = Lists.newArrayList();
|
||||
if (CollectionUtils.isEmpty(similarQueries)) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
|
||||
@@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.server.rest;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.chat.core.agent.Agent;
|
||||
import com.tencent.supersonic.chat.core.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import org.springframework.web.bind.annotation.DeleteMapping;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.chat.server.rest;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;
|
||||
@@ -8,16 +7,9 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp;
|
||||
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.server.service.ConfigService;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.PageDimensionReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.PageMetricReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ItemResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
|
||||
import com.tencent.supersonic.headless.server.service.SchemaService;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
@@ -39,8 +31,8 @@ public class ChatConfigController {
|
||||
@Autowired
|
||||
private ConfigService configService;
|
||||
|
||||
|
||||
private SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
|
||||
@Autowired
|
||||
private SchemaService schemaService;
|
||||
|
||||
@PostMapping
|
||||
public Long addChatConfig(@RequestBody ChatConfigBaseReq extendBaseCmd,
|
||||
@@ -76,40 +68,9 @@ public class ChatConfigController {
|
||||
return configService.getAllChatRichConfig();
|
||||
}
|
||||
|
||||
@GetMapping("/domainList")
|
||||
public List<DomainResp> getDomainList(HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
return semanticInterpreter.getDomainList(user);
|
||||
}
|
||||
|
||||
//Compatible with front-end
|
||||
@GetMapping("/dataSetList")
|
||||
public List<DataSetResp> getDataSetList() {
|
||||
return semanticInterpreter.getDataSetList(null);
|
||||
}
|
||||
|
||||
@GetMapping("/dataSetList/{domainId}")
|
||||
public List<DataSetResp> getDataSetList(@PathVariable("domainId") Long domainId) {
|
||||
return semanticInterpreter.getDataSetList(domainId);
|
||||
}
|
||||
|
||||
@PostMapping("/dimension/page")
|
||||
public PageInfo<DimensionResp> getDimension(@RequestBody PageDimensionReq pageDimensionReq) {
|
||||
return semanticInterpreter.getDimensionPage(pageDimensionReq);
|
||||
}
|
||||
|
||||
@PostMapping("/metric/page")
|
||||
public PageInfo<MetricResp> getMetric(@RequestBody PageMetricReq pageMetricReq,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
return semanticInterpreter.getMetricPage(pageMetricReq, user);
|
||||
}
|
||||
|
||||
@GetMapping("/getDomainDataSetTree")
|
||||
public List<ItemResp> getDomainDataSetTree() {
|
||||
return semanticInterpreter.getDomainDataSetTree();
|
||||
return schemaService.getDomainDataSetTree();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ package com.tencent.supersonic.chat.server.rest;
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
|
||||
@@ -1,24 +1,23 @@
|
||||
package com.tencent.supersonic.chat.server.rest;
|
||||
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
|
||||
import com.tencent.supersonic.chat.server.service.QueryService;
|
||||
import com.tencent.supersonic.chat.server.service.SearchService;
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import javax.validation.Valid;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import javax.servlet.http.HttpServletRequest;
|
||||
import javax.servlet.http.HttpServletResponse;
|
||||
import javax.validation.Valid;
|
||||
|
||||
/**
|
||||
* query controller
|
||||
*/
|
||||
@@ -27,62 +26,49 @@ import org.springframework.web.bind.annotation.RestController;
|
||||
public class ChatQueryController {
|
||||
|
||||
@Autowired
|
||||
@Qualifier("chatQueryService")
|
||||
private QueryService queryService;
|
||||
|
||||
@Autowired
|
||||
private SearchService searchService;
|
||||
private ChatService chatService;
|
||||
|
||||
@PostMapping("search")
|
||||
public Object search(@RequestBody QueryReq queryCtx, HttpServletRequest request,
|
||||
public Object search(@RequestBody ChatParseReq chatParseReq, HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
queryCtx.setUser(UserHolder.findUser(request, response));
|
||||
return searchService.search(queryCtx);
|
||||
chatParseReq.setUser(UserHolder.findUser(request, response));
|
||||
return chatService.search(chatParseReq);
|
||||
}
|
||||
|
||||
@PostMapping("parse")
|
||||
public Object parse(@RequestBody QueryReq queryCtx, HttpServletRequest request, HttpServletResponse response)
|
||||
throws Exception {
|
||||
queryCtx.setUser(UserHolder.findUser(request, response));
|
||||
return queryService.performParsing(queryCtx);
|
||||
public Object parse(@RequestBody ChatParseReq chatParseReq,
|
||||
HttpServletRequest request, HttpServletResponse response) throws Exception {
|
||||
chatParseReq.setUser(UserHolder.findUser(request, response));
|
||||
return chatService.performParsing(chatParseReq);
|
||||
}
|
||||
|
||||
@PostMapping("execute")
|
||||
public Object execute(@RequestBody ExecuteQueryReq queryReq,
|
||||
public Object execute(@RequestBody ChatExecuteReq chatExecuteReq,
|
||||
HttpServletRequest request, HttpServletResponse response)
|
||||
throws Exception {
|
||||
queryReq.setUser(UserHolder.findUser(request, response));
|
||||
return queryService.performExecution(queryReq);
|
||||
chatExecuteReq.setUser(UserHolder.findUser(request, response));
|
||||
return chatService.performExecution(chatExecuteReq);
|
||||
}
|
||||
|
||||
@PostMapping("queryContext")
|
||||
public Object queryContext(@RequestBody QueryReq queryCtx, HttpServletRequest request,
|
||||
HttpServletResponse response) throws Exception {
|
||||
public Object queryContext(@RequestBody QueryReq queryCtx,
|
||||
HttpServletRequest request, HttpServletResponse response) {
|
||||
queryCtx.setUser(UserHolder.findUser(request, response));
|
||||
return queryService.queryContext(queryCtx);
|
||||
return chatService.queryContext(queryCtx.getChatId());
|
||||
}
|
||||
|
||||
@PostMapping("queryData")
|
||||
public Object queryData(@RequestBody QueryDataReq queryData,
|
||||
HttpServletRequest request, HttpServletResponse response)
|
||||
throws Exception {
|
||||
HttpServletRequest request, HttpServletResponse response) throws Exception {
|
||||
queryData.setUser(UserHolder.findUser(request, response));
|
||||
return queryService.executeDirectQuery(queryData, UserHolder.findUser(request, response));
|
||||
return chatService.queryData(queryData, UserHolder.findUser(request, response));
|
||||
}
|
||||
|
||||
@PostMapping("queryDimensionValue")
|
||||
public Object queryDimensionValue(@RequestBody @Valid DimensionValueReq dimensionValueReq,
|
||||
HttpServletRequest request, HttpServletResponse response)
|
||||
throws Exception {
|
||||
return queryService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response));
|
||||
HttpServletRequest request, HttpServletResponse response) throws Exception {
|
||||
return chatService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response));
|
||||
}
|
||||
|
||||
@RequestMapping("/getEntityInfo")
|
||||
public Object getEntityInfo(Long queryId, Integer parseId,
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
return queryService.getEntityInfo(queryId, parseId, user);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.server.rest;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PluginQueryReq;
|
||||
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.server.service.PluginService;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.rest;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.request.RecommendReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.RecommendQuestionResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.RecommendResp;
|
||||
import com.tencent.supersonic.chat.server.service.RecommendService;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.PathVariable;
|
||||
import org.springframework.web.bind.annotation.RequestParam;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* recommend controller
|
||||
*/
|
||||
@RestController
|
||||
@RequestMapping({"/api/chat/", "/openapi/chat/"})
|
||||
public class RecommendController {
|
||||
|
||||
@Autowired
|
||||
private RecommendService recommendService;
|
||||
|
||||
@GetMapping("recommend/{modelId}")
|
||||
public RecommendResp recommend(@PathVariable("modelId") Long modelId,
|
||||
@RequestParam(value = "limit", required = false) Long limit) {
|
||||
RecommendReq recommendReq = new RecommendReq();
|
||||
recommendReq.setModelId(modelId);
|
||||
return recommendService.recommend(recommendReq, limit);
|
||||
}
|
||||
|
||||
@GetMapping("recommend/metric/{modelId}")
|
||||
public RecommendResp recommendMetricMode(@PathVariable("modelId") Long modelId,
|
||||
@RequestParam(value = "metricId", required = false) Long metricId,
|
||||
@RequestParam(value = "limit", required = false) Long limit) {
|
||||
RecommendReq recommendReq = new RecommendReq();
|
||||
recommendReq.setModelId(modelId);
|
||||
recommendReq.setMetricId(metricId);
|
||||
return recommendService.recommendMetricMode(recommendReq, limit);
|
||||
}
|
||||
|
||||
@GetMapping("recommend/question")
|
||||
public List<RecommendQuestionResp> recommendQuestion(
|
||||
@RequestParam(value = "modelId", required = false) Long modelId) {
|
||||
return recommendService.recommendQuestion(modelId);
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.core.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import java.util.List;
|
||||
|
||||
public interface AgentService {
|
||||
|
||||
@@ -2,30 +2,36 @@ package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface ChatService {
|
||||
|
||||
/***
|
||||
* get the model from context
|
||||
* @param chatId
|
||||
* @return
|
||||
*/
|
||||
Long getContextModel(Integer chatId);
|
||||
List<SearchResult> search(ChatParseReq chatParseReq);
|
||||
|
||||
ChatContext getOrCreateContext(int chatId);
|
||||
ParseResp performParsing(ChatParseReq chatParseReq);
|
||||
|
||||
void updateContext(ChatContext chatCtx);
|
||||
QueryResult performExecution(ChatExecuteReq chatExecuteReq) throws Exception;
|
||||
|
||||
Object queryData(QueryDataReq queryData, User user) throws Exception;
|
||||
|
||||
SemanticParseInfo queryContext(Integer chatId);
|
||||
|
||||
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
|
||||
|
||||
Boolean addChat(User user, String chatName, Integer agentId);
|
||||
|
||||
@@ -45,13 +51,13 @@ public interface ChatService {
|
||||
|
||||
ShowCaseResp queryShowCase(PageQueryInfoReq pageQueryInfoCommend, int agentId);
|
||||
|
||||
List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryContext queryContext, ParseResp parseResult);
|
||||
List<ChatParseDO> batchAddParse(ChatParseReq chatParseReq, ParseResp parseResult);
|
||||
|
||||
ChatQueryDO getLastQuery(long chatId);
|
||||
|
||||
int updateQuery(ChatQueryDO chatQueryDO);
|
||||
|
||||
void updateQuery(Long questionId, int parseId, QueryResult queryResult, ChatContext chatCtx);
|
||||
void saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult);
|
||||
|
||||
ChatParseDO getParseInfo(Long questionId, int parseId);
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PluginQueryReq;
|
||||
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.Plugin;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.EntityInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq;
|
||||
import org.apache.calcite.sql.parser.SqlParseException;
|
||||
|
||||
/***
|
||||
* QueryService for query and search
|
||||
*/
|
||||
public interface QueryService {
|
||||
|
||||
ParseResp performParsing(QueryReq queryReq);
|
||||
|
||||
QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception;
|
||||
|
||||
SemanticParseInfo queryContext(QueryReq queryReq);
|
||||
|
||||
QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws SqlParseException;
|
||||
|
||||
EntityInfo getEntityInfo(Long queryId, Integer parseId, User user);
|
||||
|
||||
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
|
||||
}
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SearchResult;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* search service
|
||||
*/
|
||||
public interface SearchService {
|
||||
|
||||
List<SearchResult> search(QueryReq queryCtx);
|
||||
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
import java.lang.annotation.Documented;
|
||||
import java.lang.annotation.Target;
|
||||
import java.lang.annotation.ElementType;
|
||||
import java.lang.annotation.Retention;
|
||||
import java.lang.annotation.RetentionPolicy;
|
||||
|
||||
@Target({ElementType.PARAMETER, ElementType.METHOD})
|
||||
@Retention(RetentionPolicy.RUNTIME)
|
||||
@Documented
|
||||
public @interface TimeCost {
|
||||
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.aspectj.lang.ProceedingJoinPoint;
|
||||
import org.aspectj.lang.annotation.Around;
|
||||
import org.aspectj.lang.annotation.Aspect;
|
||||
import org.aspectj.lang.annotation.Pointcut;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Slf4j
|
||||
@Component
|
||||
@Aspect
|
||||
public class TimeCostAOP {
|
||||
|
||||
@Pointcut("@annotation(com.tencent.supersonic.chat.server.service.TimeCost)")
|
||||
private void timeCostAdvicePointcut() {
|
||||
|
||||
}
|
||||
|
||||
@Around("timeCostAdvicePointcut()")
|
||||
public Object timeCostAdvice(ProceedingJoinPoint joinPoint) throws Throwable {
|
||||
log.info("begin to add time cost!");
|
||||
Long startTime = System.currentTimeMillis();
|
||||
Object object = joinPoint.proceed();
|
||||
if (object instanceof QueryResult) {
|
||||
QueryResult queryResult = (QueryResult) object;
|
||||
queryResult.setQueryTimeCost(System.currentTimeMillis() - startTime);
|
||||
return queryResult;
|
||||
}
|
||||
return object;
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.core.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.AgentRepository;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
|
||||
@@ -1,26 +1,36 @@
|
||||
package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.github.pagehelper.PageInfo;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.core.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PageQueryInfoReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ShowCaseResp;
|
||||
import com.tencent.supersonic.chat.core.pojo.QueryContext;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatParseDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.QueryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatContextRepository;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatRepository;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.ExecuteQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
||||
import com.tencent.supersonic.headless.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.headless.server.service.SearchService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -30,51 +40,87 @@ import java.util.Comparator;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@Service("ChatService")
|
||||
@Primary
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class ChatServiceImpl implements ChatService {
|
||||
|
||||
private ChatContextRepository chatContextRepository;
|
||||
@Autowired
|
||||
private ChatRepository chatRepository;
|
||||
@Autowired
|
||||
private ChatQueryRepository chatQueryRepository;
|
||||
@Autowired
|
||||
private ChatQueryService chatQueryService;
|
||||
@Autowired
|
||||
private AgentService agentService;
|
||||
@Autowired
|
||||
private SearchService searchService;
|
||||
|
||||
public ChatServiceImpl(ChatContextRepository chatContextRepository, ChatRepository chatRepository,
|
||||
ChatQueryRepository chatQueryRepository) {
|
||||
this.chatContextRepository = chatContextRepository;
|
||||
this.chatRepository = chatRepository;
|
||||
this.chatQueryRepository = chatQueryRepository;
|
||||
@Override
|
||||
public List<SearchResult> search(ChatParseReq chatParseReq) {
|
||||
QueryReq queryReq = buildSqlQueryReq(chatParseReq);
|
||||
return searchService.search(queryReq);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long getContextModel(Integer chatId) {
|
||||
if (Objects.isNull(chatId)) {
|
||||
return null;
|
||||
}
|
||||
ChatContext chatContext = getOrCreateContext(chatId);
|
||||
if (Objects.isNull(chatContext)) {
|
||||
return null;
|
||||
}
|
||||
SemanticParseInfo originalSemanticParse = chatContext.getParseInfo();
|
||||
if (Objects.nonNull(originalSemanticParse) && Objects.nonNull(originalSemanticParse.getDataSetId())) {
|
||||
return originalSemanticParse.getDataSetId();
|
||||
}
|
||||
return null;
|
||||
public ParseResp performParsing(ChatParseReq chatParseReq) {
|
||||
QueryReq queryReq = buildSqlQueryReq(chatParseReq);
|
||||
ParseResp parseResp = chatQueryService.performParsing(queryReq);
|
||||
batchAddParse(chatParseReq, parseResp);
|
||||
return parseResp;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatContext getOrCreateContext(int chatId) {
|
||||
return chatContextRepository.getOrCreateContext(chatId);
|
||||
public QueryResult performExecution(ChatExecuteReq chatExecuteReq) throws Exception {
|
||||
ExecuteQueryReq executeQueryReq = buildExecuteReq(chatExecuteReq);
|
||||
QueryResult queryResult = chatQueryService.performExecution(executeQueryReq);
|
||||
saveQueryResult(chatExecuteReq, queryResult);
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateContext(ChatContext chatCtx) {
|
||||
log.debug("save ChatContext {}", chatCtx);
|
||||
chatContextRepository.updateContext(chatCtx);
|
||||
public Object queryData(QueryDataReq queryData, User user) throws Exception {
|
||||
return chatQueryService.executeDirectQuery(queryData, user);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SemanticParseInfo queryContext(Integer chatId) {
|
||||
return chatQueryService.queryContext(chatId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception {
|
||||
return chatQueryService.queryDimensionValue(dimensionValueReq, user);
|
||||
}
|
||||
|
||||
private QueryReq buildSqlQueryReq(ChatParseReq chatParseReq) {
|
||||
QueryReq queryReq = new QueryReq();
|
||||
BeanMapper.mapper(chatParseReq, queryReq);
|
||||
if (chatParseReq.getAgentId() == null) {
|
||||
return queryReq;
|
||||
}
|
||||
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
|
||||
if (agent == null) {
|
||||
return queryReq;
|
||||
}
|
||||
queryReq.setDataSetIds(agent.getDataSetIds());
|
||||
return queryReq;
|
||||
}
|
||||
|
||||
private ExecuteQueryReq buildExecuteReq(ChatExecuteReq chatExecuteReq) {
|
||||
ChatParseDO chatParseDO = getParseInfo(chatExecuteReq.getQueryId(), chatExecuteReq.getParseId());
|
||||
SemanticParseInfo parseInfo = JSONObject.parseObject(chatParseDO.getParseInfo(), SemanticParseInfo.class);
|
||||
return ExecuteQueryReq.builder()
|
||||
.queryId(chatExecuteReq.getQueryId())
|
||||
.chatId(chatExecuteReq.getChatId())
|
||||
.queryText(chatExecuteReq.getQueryText())
|
||||
.parseInfo(parseInfo)
|
||||
.saveAnswer(chatExecuteReq.isSaveAnswer())
|
||||
.user(chatExecuteReq.getUser())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -190,18 +236,18 @@ public class ChatServiceImpl implements ChatService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateQuery(Long questionId, int parseId, QueryResult queryResult, ChatContext chatCtx) {
|
||||
public void saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult) {
|
||||
//The history record only retains the query result of the first parse
|
||||
if (parseId > 1) {
|
||||
if (chatExecuteReq.getParseId() > 1) {
|
||||
return;
|
||||
}
|
||||
ChatQueryDO chatQueryDO = new ChatQueryDO();
|
||||
chatQueryDO.setQuestionId(questionId);
|
||||
chatQueryDO.setQuestionId(chatExecuteReq.getQueryId());
|
||||
chatQueryDO.setQueryResult(JsonUtil.toString(queryResult));
|
||||
chatQueryDO.setQueryState(1);
|
||||
updateQuery(chatQueryDO);
|
||||
chatRepository.updateLastQuestion(chatCtx.getChatId().longValue(),
|
||||
chatCtx.getQueryText(), getCurrentTime());
|
||||
chatRepository.updateLastQuestion(chatExecuteReq.getChatId().longValue(),
|
||||
chatExecuteReq.getQueryText(), getCurrentTime());
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -210,9 +256,9 @@ public class ChatServiceImpl implements ChatService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ChatParseDO> batchAddParse(ChatContext chatCtx, QueryContext queryContext, ParseResp parseResult) {
|
||||
public List<ChatParseDO> batchAddParse(ChatParseReq chatParseReq, ParseResp parseResult) {
|
||||
List<SemanticParseInfo> candidateParses = parseResult.getSelectedParses();
|
||||
return chatQueryRepository.batchSaveParseInfo(chatCtx, queryContext, parseResult, candidateParses);
|
||||
return chatQueryRepository.batchSaveParseInfo(chatParseReq, parseResult, candidateParses);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -3,9 +3,6 @@ package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq;
|
||||
@@ -24,23 +21,22 @@ import com.tencent.supersonic.chat.api.pojo.response.ChatDetailRichConfigResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.EntityRichInfoResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ItemVisibilityInfo;
|
||||
import com.tencent.supersonic.chat.server.config.ChatConfig;
|
||||
import com.tencent.supersonic.chat.server.util.ChatConfigHelper;
|
||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.server.util.VisibilityEvent;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatConfigRepository;
|
||||
import com.tencent.supersonic.chat.server.service.ConfigService;
|
||||
import com.tencent.supersonic.chat.server.service.SemanticService;
|
||||
import com.tencent.supersonic.chat.server.util.ChatConfigHelper;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaItem;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
|
||||
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
|
||||
import com.tencent.supersonic.headless.server.service.DimensionService;
|
||||
import com.tencent.supersonic.headless.server.service.MetricService;
|
||||
import com.tencent.supersonic.headless.server.service.impl.SemanticService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.context.ApplicationEventPublisher;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
@@ -62,10 +58,6 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
private final MetricService metricService;
|
||||
@Autowired
|
||||
private SemanticService semanticService;
|
||||
@Autowired
|
||||
private ApplicationEventPublisher applicationEventPublisher;
|
||||
|
||||
private SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
|
||||
|
||||
|
||||
public ConfigServiceImpl(ChatConfigRepository chatConfigRepository,
|
||||
@@ -83,9 +75,7 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
log.info("[create model extend] object:{}", JsonUtil.toString(configBaseCmd, true));
|
||||
duplicateCheck(configBaseCmd.getModelId());
|
||||
ChatConfig chaConfig = chatConfigHelper.newChatConfig(configBaseCmd, user);
|
||||
Long id = chatConfigRepository.createConfig(chaConfig);
|
||||
applicationEventPublisher.publishEvent(new VisibilityEvent(this, chaConfig));
|
||||
return id;
|
||||
return chatConfigRepository.createConfig(chaConfig);
|
||||
}
|
||||
|
||||
private void duplicateCheck(Long modelId) {
|
||||
@@ -106,7 +96,6 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
}
|
||||
ChatConfig chaConfig = chatConfigHelper.editChatConfig(configEditCmd, user);
|
||||
chatConfigRepository.updateConfig(chaConfig);
|
||||
applicationEventPublisher.publishEvent(new VisibilityEvent(this, chaConfig));
|
||||
return configEditCmd.getId();
|
||||
}
|
||||
|
||||
@@ -350,15 +339,7 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
|
||||
@Override
|
||||
public List<ChatConfigRichResp> getAllChatRichConfig() {
|
||||
List<ChatConfigRichResp> chatConfigRichInfoList = new ArrayList<>();
|
||||
List<DataSetSchema> modelSchemas = semanticInterpreter.getDataSetSchema();
|
||||
modelSchemas.stream().forEach(modelSchema -> {
|
||||
ChatConfigRichResp chatConfigRichInfo = getConfigRichInfo(modelSchema.getDataSet().getId());
|
||||
if (Objects.nonNull(chatConfigRichInfo)) {
|
||||
chatConfigRichInfoList.add(chatConfigRichInfo);
|
||||
}
|
||||
});
|
||||
return chatConfigRichInfoList;
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -367,4 +348,5 @@ public class ConfigServiceImpl implements ConfigService {
|
||||
return allChatRichConfig.stream()
|
||||
.collect(Collectors.toMap(ChatConfigRichResp::getModelId, value -> value, (k1, k2) -> k1));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -4,11 +4,11 @@ import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.PluginQueryReq;
|
||||
import com.tencent.supersonic.chat.core.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.core.plugin.PluginParseConfig;
|
||||
import com.tencent.supersonic.chat.core.plugin.event.PluginAddEvent;
|
||||
import com.tencent.supersonic.chat.core.plugin.event.PluginDelEvent;
|
||||
import com.tencent.supersonic.chat.core.plugin.event.PluginUpdateEvent;
|
||||
import com.tencent.supersonic.chat.server.plugin.Plugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginAddEvent;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.PluginDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.PluginRepository;
|
||||
import com.tencent.supersonic.chat.server.service.PluginService;
|
||||
|
||||
@@ -1,142 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.headless.api.pojo.RelatedSchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigFilter;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.RecommendReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.RecommendQuestionResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.RecommendResp;
|
||||
import com.tencent.supersonic.chat.server.service.ConfigService;
|
||||
import com.tencent.supersonic.chat.server.service.RecommendService;
|
||||
import com.tencent.supersonic.chat.server.service.SemanticService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/***
|
||||
* Recommend Service impl
|
||||
*/
|
||||
@Service
|
||||
@Slf4j
|
||||
public class RecommendServiceImpl implements RecommendService {
|
||||
|
||||
@Autowired
|
||||
private ConfigService configService;
|
||||
@Autowired
|
||||
private SemanticService semanticService;
|
||||
|
||||
@Override
|
||||
public RecommendResp recommend(RecommendReq recommendReq, Long limit) {
|
||||
if (Objects.isNull(limit) || limit <= 0) {
|
||||
limit = Long.MAX_VALUE;
|
||||
}
|
||||
Long modelId = recommendReq.getModelId();
|
||||
if (Objects.isNull(modelId)) {
|
||||
return new RecommendResp();
|
||||
}
|
||||
DataSetSchema modelSchema = semanticService.getDataSetSchema(modelId);
|
||||
if (Objects.isNull(modelSchema)) {
|
||||
return new RecommendResp();
|
||||
}
|
||||
List<Long> drillDownDimensions = Lists.newArrayList();
|
||||
Set<SchemaElement> metricElements = modelSchema.getMetrics();
|
||||
if (recommendReq.getMetricId() != null && !CollectionUtils.isEmpty(metricElements)) {
|
||||
Optional<SchemaElement> metric = metricElements.stream().filter(schemaElement ->
|
||||
recommendReq.getMetricId().equals(schemaElement.getId())
|
||||
&& !CollectionUtils.isEmpty(schemaElement.getRelatedSchemaElements()))
|
||||
.findFirst();
|
||||
if (metric.isPresent()) {
|
||||
drillDownDimensions = metric.get().getRelatedSchemaElements().stream()
|
||||
.map(RelatedSchemaElement::getDimensionId).collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
final List<Long> drillDownDimensionsFinal = drillDownDimensions;
|
||||
List<SchemaElement> dimensions = modelSchema.getDimensions().stream()
|
||||
.filter(dim -> {
|
||||
if (Objects.isNull(dim)) {
|
||||
return false;
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(drillDownDimensionsFinal)) {
|
||||
return drillDownDimensionsFinal.contains(dim.getId());
|
||||
} else {
|
||||
return Objects.nonNull(dim.getUseCnt());
|
||||
}
|
||||
})
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(limit)
|
||||
.map(dimSchemaDesc -> {
|
||||
SchemaElement item = new SchemaElement();
|
||||
item.setDataSet(modelId);
|
||||
item.setName(dimSchemaDesc.getName());
|
||||
item.setBizName(dimSchemaDesc.getBizName());
|
||||
item.setId(dimSchemaDesc.getId());
|
||||
item.setAlias(dimSchemaDesc.getAlias());
|
||||
return item;
|
||||
}).collect(Collectors.toList());
|
||||
|
||||
List<SchemaElement> metrics = modelSchema.getMetrics().stream()
|
||||
.filter(metric -> Objects.nonNull(metric) && Objects.nonNull(metric.getUseCnt()))
|
||||
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
|
||||
.limit(limit)
|
||||
.map(metricSchemaDesc -> {
|
||||
SchemaElement item = new SchemaElement();
|
||||
item.setDataSet(modelId);
|
||||
item.setName(metricSchemaDesc.getName());
|
||||
item.setBizName(metricSchemaDesc.getBizName());
|
||||
item.setId(metricSchemaDesc.getId());
|
||||
item.setAlias(metricSchemaDesc.getAlias());
|
||||
return item;
|
||||
}).collect(Collectors.toList());
|
||||
|
||||
RecommendResp response = new RecommendResp();
|
||||
response.setDimensions(dimensions);
|
||||
response.setMetrics(metrics);
|
||||
return response;
|
||||
}
|
||||
|
||||
@Override
|
||||
public RecommendResp recommendMetricMode(RecommendReq recommendReq, Long limit) {
|
||||
return recommend(recommendReq, limit);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<RecommendQuestionResp> recommendQuestion(Long modelId) {
|
||||
List<RecommendQuestionResp> recommendQuestions = new ArrayList<>();
|
||||
ChatConfigFilter chatConfigFilter = new ChatConfigFilter();
|
||||
chatConfigFilter.setModelId(modelId);
|
||||
List<ChatConfigResp> chatConfigRespList = configService.search(chatConfigFilter, null);
|
||||
if (!CollectionUtils.isEmpty(chatConfigRespList)) {
|
||||
chatConfigRespList.stream().forEach(chatConfigResp -> {
|
||||
if (Objects.nonNull(chatConfigResp)
|
||||
&& !CollectionUtils.isEmpty(chatConfigResp.getRecommendedQuestions())) {
|
||||
recommendQuestions.add(
|
||||
new RecommendQuestionResp(chatConfigResp.getModelId(),
|
||||
chatConfigResp.getRecommendedQuestions()));
|
||||
}
|
||||
});
|
||||
return recommendQuestions;
|
||||
}
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
private List<SchemaElement> filterBlackItem(List<SchemaElement> itemList, List<Long> blackDimIdList) {
|
||||
if (CollectionUtils.isEmpty(blackDimIdList) || CollectionUtils.isEmpty(itemList)) {
|
||||
return itemList;
|
||||
}
|
||||
|
||||
return itemList.stream().filter(dim -> !blackDimIdList.contains(dim.getId())).collect(Collectors.toList());
|
||||
}
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
import com.google.common.cache.CacheBuilder;
|
||||
import com.google.common.cache.CacheLoader;
|
||||
import com.google.common.cache.LoadingCache;
|
||||
import com.tencent.supersonic.chat.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.core.utils.ComponentFactory;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class SchemaService {
|
||||
|
||||
|
||||
public static final String ALL_CACHE = "all";
|
||||
private static final Integer META_CACHE_TIME = 30;
|
||||
private SemanticInterpreter semanticInterpreter = ComponentFactory.getSemanticLayer();
|
||||
|
||||
private LoadingCache<String, SemanticSchema> cache = CacheBuilder.newBuilder()
|
||||
.expireAfterWrite(META_CACHE_TIME, TimeUnit.SECONDS)
|
||||
.build(
|
||||
new CacheLoader<String, SemanticSchema>() {
|
||||
@Override
|
||||
public SemanticSchema load(String key) {
|
||||
log.info("load getDomainSchemaInfo cache [{}]", key);
|
||||
return new SemanticSchema(semanticInterpreter.getDataSetSchema());
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
public DataSetSchema getDataSetSchema(Long id) {
|
||||
return semanticInterpreter.getDataSetSchema(id, true);
|
||||
}
|
||||
|
||||
public SemanticSchema getSemanticSchema() {
|
||||
return cache.getUnchecked(ALL_CACHE);
|
||||
}
|
||||
|
||||
public LoadingCache<String, SemanticSchema> getCache() {
|
||||
return cache;
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.StatisticsRepository;
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
|
||||
import com.tencent.supersonic.chat.server.service.StatisticsService;
|
||||
import com.tencent.supersonic.headless.server.persistence.mapper.StatisticsMapper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
@@ -15,11 +15,11 @@ import java.util.List;
|
||||
public class StatisticsServiceImpl implements StatisticsService {
|
||||
|
||||
@Autowired
|
||||
private StatisticsRepository statisticsRepository;
|
||||
private StatisticsMapper statisticsMapper;
|
||||
|
||||
@Async
|
||||
@Override
|
||||
public void batchSaveStatistics(List<StatisticsDO> list) {
|
||||
statisticsRepository.batchSaveStatistics(list);
|
||||
statisticsMapper.batchSaveStatistics(list);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.chat.server.util;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.ADMIN_LOWER;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq;
|
||||
|
||||
@@ -1,43 +1,21 @@
|
||||
package com.tencent.supersonic.chat.server.util;
|
||||
|
||||
import com.tencent.supersonic.chat.core.corrector.SemanticCorrector;
|
||||
import com.tencent.supersonic.chat.core.query.semantic.SemanticInterpreter;
|
||||
import com.tencent.supersonic.chat.core.mapper.SchemaMapper;
|
||||
import com.tencent.supersonic.chat.core.parser.SemanticParser;
|
||||
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import com.tencent.supersonic.headless.server.processor.ResultProcessor;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.core.io.support.SpringFactoriesLoader;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
public class ComponentFactory {
|
||||
|
||||
private static List<SchemaMapper> schemaMappers = new ArrayList<>();
|
||||
private static List<SemanticParser> semanticParsers = new ArrayList<>();
|
||||
private static List<SemanticCorrector> semanticCorrectors = new ArrayList<>();
|
||||
private static SemanticInterpreter semanticInterpreter;
|
||||
private static List<ParseResultProcessor> parseProcessors = new ArrayList<>();
|
||||
private static List<ResultProcessor> parseProcessors = new ArrayList<>();
|
||||
private static List<ExecuteResultProcessor> executeProcessors = new ArrayList<>();
|
||||
|
||||
public static List<SchemaMapper> getSchemaMappers() {
|
||||
return CollectionUtils.isEmpty(schemaMappers) ? init(SchemaMapper.class, schemaMappers) : schemaMappers;
|
||||
}
|
||||
|
||||
public static List<SemanticParser> getSemanticParsers() {
|
||||
return CollectionUtils.isEmpty(semanticParsers) ? init(SemanticParser.class, semanticParsers) : semanticParsers;
|
||||
}
|
||||
|
||||
public static List<SemanticCorrector> getSemanticCorrectors() {
|
||||
return CollectionUtils.isEmpty(semanticCorrectors) ? init(SemanticCorrector.class,
|
||||
semanticCorrectors) : semanticCorrectors;
|
||||
}
|
||||
|
||||
public static List<ParseResultProcessor> getParseProcessors() {
|
||||
return CollectionUtils.isEmpty(parseProcessors) ? init(ParseResultProcessor.class,
|
||||
public static List<ResultProcessor> getParseProcessors() {
|
||||
return CollectionUtils.isEmpty(parseProcessors) ? init(ResultProcessor.class,
|
||||
parseProcessors) : parseProcessors;
|
||||
}
|
||||
|
||||
@@ -46,13 +24,6 @@ public class ComponentFactory {
|
||||
? init(ExecuteResultProcessor.class, executeProcessors) : executeProcessors;
|
||||
}
|
||||
|
||||
public static SemanticInterpreter getSemanticLayer() {
|
||||
if (Objects.isNull(semanticInterpreter)) {
|
||||
semanticInterpreter = init(SemanticInterpreter.class);
|
||||
}
|
||||
return semanticInterpreter;
|
||||
}
|
||||
|
||||
private static <T> List<T> init(Class<T> factoryType, List list) {
|
||||
list.addAll(SpringFactoriesLoader.loadFactories(factoryType,
|
||||
Thread.currentThread().getContextClassLoader()));
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.util;
|
||||
|
||||
import com.tencent.supersonic.chat.server.config.ChatConfig;
|
||||
import org.springframework.context.ApplicationEvent;
|
||||
|
||||
public class VisibilityEvent extends ApplicationEvent {
|
||||
private static final long serialVersionUID = 1L;
|
||||
private ChatConfig chatConfig;
|
||||
|
||||
public VisibilityEvent(Object source, ChatConfig chatConfig) {
|
||||
super(source);
|
||||
this.chatConfig = chatConfig;
|
||||
}
|
||||
|
||||
public void setChatConfig(ChatConfig chatConfig) {
|
||||
this.chatConfig = chatConfig;
|
||||
}
|
||||
|
||||
public ChatConfig getChatConfig() {
|
||||
return chatConfig;
|
||||
}
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.util;
|
||||
|
||||
import com.github.benmanes.caffeine.cache.Cache;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ItemNameVisibilityInfo;
|
||||
import com.tencent.supersonic.chat.server.service.ConfigService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.context.ApplicationListener;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
public class VisibilityListener implements ApplicationListener<VisibilityEvent> {
|
||||
|
||||
@Autowired
|
||||
@Qualifier("searchCaffeineCache")
|
||||
private Cache<Long, Object> caffeineCache;
|
||||
|
||||
@Autowired
|
||||
private ConfigService configService;
|
||||
|
||||
@Override
|
||||
public void onApplicationEvent(VisibilityEvent event) {
|
||||
log.info("visibility has changed,so update cache!");
|
||||
ItemNameVisibilityInfo itemNameVisibility = configService.getItemNameVisibility(event.getChatConfig());
|
||||
log.info("itemNameVisibility :{}", itemNameVisibility);
|
||||
caffeineCache.put(event.getChatConfig().getModelId(), itemNameVisibility);
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user