mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
(improvement)(headless&chat)Move ChatContext from Headless module to Chat module.
This commit is contained in:
@@ -1,10 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
|
||||
public interface ChatExecutor {
|
||||
|
||||
QueryResult execute(ChatExecuteContext chatExecuteContext);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
|
||||
public interface ChatQueryExecutor {
|
||||
|
||||
QueryResult execute(ExecuteContext executeContext);
|
||||
|
||||
}
|
||||
@@ -4,7 +4,7 @@ import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.parser.ParserConfig;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
@@ -23,7 +23,7 @@ import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
||||
|
||||
public class PlainTextExecutor implements ChatExecutor {
|
||||
public class PlainTextExecutor implements ChatQueryExecutor {
|
||||
|
||||
private static final String INSTRUCTION = ""
|
||||
+ "#Role: You are a nice person to talk to.\n"
|
||||
@@ -34,34 +34,34 @@ public class PlainTextExecutor implements ChatExecutor {
|
||||
+ "#Your response: ";
|
||||
|
||||
@Override
|
||||
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
|
||||
if (!"PLAIN_TEXT".equals(chatExecuteContext.getParseInfo().getQueryMode())) {
|
||||
public QueryResult execute(ExecuteContext executeContext) {
|
||||
if (!"PLAIN_TEXT".equals(executeContext.getParseInfo().getQueryMode())) {
|
||||
return null;
|
||||
}
|
||||
|
||||
String promptStr = String.format(INSTRUCTION, getHistoryInputs(chatExecuteContext),
|
||||
chatExecuteContext.getQueryText());
|
||||
String promptStr = String.format(INSTRUCTION, getHistoryInputs(executeContext),
|
||||
executeContext.getQueryText());
|
||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
|
||||
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatAgent.getModelConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
|
||||
QueryResult result = new QueryResult();
|
||||
result.setQueryState(QueryState.SUCCESS);
|
||||
result.setQueryMode(chatExecuteContext.getParseInfo().getQueryMode());
|
||||
result.setQueryMode(executeContext.getParseInfo().getQueryMode());
|
||||
result.setTextResult(response.content().text());
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private String getHistoryInputs(ChatExecuteContext chatExecuteContext) {
|
||||
private String getHistoryInputs(ExecuteContext executeContext) {
|
||||
StringBuilder historyInput = new StringBuilder();
|
||||
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
|
||||
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
|
||||
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
MultiTurnConfig agentMultiTurnConfig = chatAgent.getMultiTurnConfig();
|
||||
@@ -70,7 +70,7 @@ public class PlainTextExecutor implements ChatExecutor {
|
||||
? agentMultiTurnConfig.isEnableMultiTurn() : globalMultiTurnConfig;
|
||||
|
||||
if (Boolean.TRUE.equals(multiTurnConfig)) {
|
||||
List<ParseResp> parseResps = getHistoryParseResult(chatExecuteContext.getChatId(), 5);
|
||||
List<ParseResp> parseResps = getHistoryParseResult(executeContext.getChatId(), 5);
|
||||
parseResps.stream().forEach(p -> {
|
||||
historyInput.append(p.getQueryText());
|
||||
historyInput.append(";");
|
||||
|
||||
@@ -2,15 +2,15 @@ package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
|
||||
public class PluginExecutor implements ChatExecutor {
|
||||
public class PluginExecutor implements ChatQueryExecutor {
|
||||
|
||||
@Override
|
||||
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
|
||||
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo();
|
||||
public QueryResult execute(ExecuteContext executeContext) {
|
||||
SemanticParseInfo parseInfo = executeContext.getParseInfo();
|
||||
if (!PluginQueryManager.isPluginQuery(parseInfo.getQueryMode())) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.chat.server.util.ResultFormatter;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
@@ -12,10 +12,10 @@ import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import com.tencent.supersonic.headless.server.web.service.ChatContextService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||
import lombok.SneakyThrows;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@@ -25,12 +25,12 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
public class SqlExecutor implements ChatExecutor {
|
||||
public class SqlExecutor implements ChatQueryExecutor {
|
||||
|
||||
@SneakyThrows
|
||||
@Override
|
||||
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
|
||||
QueryResult queryResult = doExecute(chatExecuteContext);
|
||||
public QueryResult execute(ExecuteContext executeContext) {
|
||||
QueryResult queryResult = doExecute(executeContext);
|
||||
|
||||
if (queryResult != null) {
|
||||
String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(),
|
||||
@@ -41,13 +41,13 @@ public class SqlExecutor implements ChatExecutor {
|
||||
&& queryResult.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
|
||||
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
|
||||
memoryService.createMemory(ChatMemoryDO.builder()
|
||||
.agentId(chatExecuteContext.getAgentId())
|
||||
.agentId(executeContext.getAgent().getId())
|
||||
.status(MemoryStatus.PENDING)
|
||||
.question(chatExecuteContext.getQueryText())
|
||||
.s2sql(chatExecuteContext.getParseInfo().getSqlInfo().getParsedS2SQL())
|
||||
.dbSchema(buildSchemaStr(chatExecuteContext.getParseInfo()))
|
||||
.createdBy(chatExecuteContext.getUser().getName())
|
||||
.updatedBy(chatExecuteContext.getUser().getName())
|
||||
.question(executeContext.getQueryText())
|
||||
.s2sql(executeContext.getParseInfo().getSqlInfo().getParsedS2SQL())
|
||||
.dbSchema(buildSchemaStr(executeContext.getParseInfo()))
|
||||
.createdBy(executeContext.getUser().getName())
|
||||
.updatedBy(executeContext.getUser().getName())
|
||||
.createdAt(new Date())
|
||||
.build());
|
||||
}
|
||||
@@ -57,12 +57,12 @@ public class SqlExecutor implements ChatExecutor {
|
||||
}
|
||||
|
||||
@SneakyThrows
|
||||
private QueryResult doExecute(ChatExecuteContext chatExecuteContext) {
|
||||
private QueryResult doExecute(ExecuteContext executeContext) {
|
||||
SemanticLayerService semanticLayer = ContextUtils.getBean(SemanticLayerService.class);
|
||||
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
||||
|
||||
ChatContext chatCtx = chatContextService.getOrCreateContext(chatExecuteContext.getChatId());
|
||||
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo();
|
||||
ChatContext chatCtx = chatContextService.getOrCreateContext(executeContext.getChatId());
|
||||
SemanticParseInfo parseInfo = executeContext.getParseInfo();
|
||||
if (Objects.isNull(parseInfo.getSqlInfo())
|
||||
|| StringUtils.isBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) {
|
||||
return null;
|
||||
@@ -74,8 +74,10 @@ public class SqlExecutor implements ChatExecutor {
|
||||
sqlReq.setSqlInfo(parseInfo.getSqlInfo());
|
||||
sqlReq.setDataSetId(parseInfo.getDataSetId());
|
||||
long startTime = System.currentTimeMillis();
|
||||
SemanticQueryResp queryResp = semanticLayer.queryByReq(sqlReq, chatExecuteContext.getUser());
|
||||
SemanticQueryResp queryResp = semanticLayer.queryByReq(sqlReq, executeContext.getUser());
|
||||
QueryResult queryResult = new QueryResult();
|
||||
queryResult.setChatContext(parseInfo);
|
||||
queryResult.setQueryMode(parseInfo.getQueryMode());
|
||||
if (queryResp != null) {
|
||||
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
|
||||
List<Map<String, Object>> resultList = queryResp == null ? new ArrayList<>()
|
||||
@@ -85,7 +87,6 @@ public class SqlExecutor implements ChatExecutor {
|
||||
queryResult.setQuerySql(queryResp.getSql());
|
||||
queryResult.setQueryResults(resultList);
|
||||
queryResult.setQueryColumns(columns);
|
||||
queryResult.setQueryMode(parseInfo.getQueryMode());
|
||||
queryResult.setQueryState(QueryState.SUCCESS);
|
||||
|
||||
chatCtx.setParseInfo(parseInfo);
|
||||
@@ -94,7 +95,6 @@ public class SqlExecutor implements ChatExecutor {
|
||||
queryResult.setQueryState(QueryState.INVALID);
|
||||
queryResult.setQueryMode(parseInfo.getQueryMode());
|
||||
}
|
||||
queryResult.setChatContext(chatCtx.getParseInfo());
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
|
||||
public interface ChatParser {
|
||||
|
||||
void parse(ChatParseContext chatParseContext, ParseResp parseResp);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
|
||||
public interface ChatQueryParser {
|
||||
|
||||
void parse(ParseContext parseContext, ParseResp parseResp);
|
||||
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
@@ -9,18 +9,18 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
public class NL2PluginParser implements ChatParser {
|
||||
public class NL2PluginParser implements ChatQueryParser {
|
||||
|
||||
private final List<PluginRecognizer> pluginRecognizers = ComponentFactory.getPluginRecognizers();
|
||||
|
||||
@Override
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
if (!chatParseContext.getAgent().containsPluginTool()) {
|
||||
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
||||
if (!parseContext.getAgent().containsPluginTool()) {
|
||||
return;
|
||||
}
|
||||
|
||||
pluginRecognizers.forEach(pluginRecognizer -> {
|
||||
pluginRecognizer.recognize(chatParseContext, parseResp);
|
||||
pluginRecognizer.recognize(parseContext, parseResp);
|
||||
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),
|
||||
JsonUtil.toString(parseResp));
|
||||
});
|
||||
|
||||
@@ -3,7 +3,8 @@ package com.tencent.supersonic.chat.server.parser;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
@@ -17,7 +18,8 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
@@ -42,7 +44,7 @@ import java.util.stream.Collectors;
|
||||
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
||||
|
||||
@Slf4j
|
||||
public class NL2SQLParser implements ChatParser {
|
||||
public class NL2SQLParser implements ChatQueryParser {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
@@ -73,27 +75,30 @@ public class NL2SQLParser implements ChatParser {
|
||||
+ "#Response: ";
|
||||
|
||||
@Override
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) {
|
||||
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
||||
if (!parseContext.enableNL2SQL() || checkSkip(parseResp)) {
|
||||
return;
|
||||
}
|
||||
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
||||
ChatContext chatCtx = chatContextService.getOrCreateContext(parseContext.getChatId());
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
|
||||
chatParseContext.getAgent().getModelConfig());
|
||||
parseContext.getAgent().getModelConfig());
|
||||
|
||||
processMultiTurn(chatLanguageModel, chatParseContext);
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
addDynamicExemplars(chatParseContext.getAgent().getId(), queryNLReq);
|
||||
processMultiTurn(chatLanguageModel, parseContext);
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx);
|
||||
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
|
||||
|
||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryNLReq);
|
||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||
ParseResp text2SqlParseResp = chatLayerService.performParsing(queryNLReq);
|
||||
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
|
||||
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
|
||||
} else {
|
||||
parseResp.setErrorMsg(rewriteErrorMessage(chatLanguageModel,
|
||||
chatParseContext.getQueryText(),
|
||||
parseContext.getQueryText(),
|
||||
text2SqlParseResp.getErrorMsg(),
|
||||
queryNLReq.getDynamicExemplars(),
|
||||
chatParseContext.getAgent().getExamples()));
|
||||
parseContext.getAgent().getExamples()));
|
||||
}
|
||||
parseResp.setState(text2SqlParseResp.getState());
|
||||
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
|
||||
@@ -155,9 +160,9 @@ public class NL2SQLParser implements ChatParser {
|
||||
parseInfo.setTextInfo(textBuilder.toString());
|
||||
}
|
||||
|
||||
private void processMultiTurn(ChatLanguageModel chatLanguageModel, ChatParseContext chatParseContext) {
|
||||
private void processMultiTurn(ChatLanguageModel chatLanguageModel, ParseContext parseContext) {
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig();
|
||||
MultiTurnConfig agentMultiTurnConfig = parseContext.getAgent().getMultiTurnConfig();
|
||||
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
|
||||
|
||||
Boolean multiTurnConfig = agentMultiTurnConfig != null
|
||||
@@ -167,11 +172,11 @@ public class NL2SQLParser implements ChatParser {
|
||||
}
|
||||
|
||||
// derive mapping result of current question and parsing result of last question.
|
||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
MapResp currentMapResult = chatQueryService.performMapping(queryNLReq);
|
||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
||||
MapResp currentMapResult = chatLayerService.performMapping(queryNLReq);
|
||||
|
||||
List<ParseResp> historyParseResults = getHistoryParseResult(chatParseContext.getChatId(), 1);
|
||||
List<ParseResp> historyParseResults = getHistoryParseResult(parseContext.getChatId(), 1);
|
||||
if (historyParseResults.size() == 0) {
|
||||
return;
|
||||
}
|
||||
@@ -196,7 +201,7 @@ public class NL2SQLParser implements ChatParser {
|
||||
String rewrittenQuery = response.content().text();
|
||||
keyPipelineLog.info("NL2SQLParser modelResp:{}", rewrittenQuery);
|
||||
|
||||
chatParseContext.setQueryText(rewrittenQuery);
|
||||
parseContext.setQueryText(rewrittenQuery);
|
||||
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
|
||||
lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service("ChatParserConfig")
|
||||
@Service("ChatQueryParserConfig")
|
||||
@Slf4j
|
||||
public class ParserConfig extends ParameterConfig {
|
||||
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
|
||||
|
||||
public class PlainTextParser implements ChatParser {
|
||||
public class PlainTextParser implements ChatQueryParser {
|
||||
|
||||
@Override
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
if (chatParseContext.getAgent().containsAnyTool()) {
|
||||
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
||||
if (parseContext.getAgent().containsAnyTool()) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.headless.server.persistence.dataobject;
|
||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.headless.server.persistence.dataobject;
|
||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.headless.server.persistence.mapper;
|
||||
package com.tencent.supersonic.chat.server.persistence.mapper;
|
||||
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
@Mapper
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.headless.server.persistence.mapper;
|
||||
package com.tencent.supersonic.chat.server.persistence.mapper;
|
||||
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
import org.apache.ibatis.annotations.Param;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.headless.server.persistence.repository;
|
||||
package com.tencent.supersonic.chat.server.persistence.repository;
|
||||
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
|
||||
public interface ChatContextRepository {
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package com.tencent.supersonic.headless.server.persistence.repository.impl;
|
||||
package com.tencent.supersonic.chat.server.persistence.repository.impl;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatContextRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO;
|
||||
import com.tencent.supersonic.headless.server.persistence.mapper.ChatContextMapper;
|
||||
import com.tencent.supersonic.headless.server.persistence.repository.ChatContextRepository;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.ChatContextMapper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
import org.springframework.stereotype.Repository;
|
||||
@@ -183,7 +183,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
public List<ParseResp> getContextualParseInfo(Integer chatId) {
|
||||
List<ChatParseDO> chatParseDOList = chatParseMapper.getContextualParseInfo(chatId);
|
||||
List<ParseResp> semanticParseInfoList = chatParseDOList.stream().map(parseInfo -> {
|
||||
ParseResp parseResp = new ParseResp(chatId, parseInfo.getQueryText());
|
||||
ParseResp parseResp = new ParseResp(parseInfo.getQueryText());
|
||||
List<SemanticParseInfo> selectedParses = new ArrayList<>();
|
||||
selectedParses.add(JSONObject.parseObject(parseInfo.getParseInfo(), SemanticParseInfo.class));
|
||||
parseResp.setSelectedParses(selectedParses);
|
||||
|
||||
@@ -11,7 +11,7 @@ import com.tencent.supersonic.chat.server.plugin.build.WebBase;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginAddEvent;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.service.PluginService;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||
@@ -52,9 +52,9 @@ public class PluginManager {
|
||||
@Autowired
|
||||
private EmbeddingService embeddingService;
|
||||
|
||||
public static List<ChatPlugin> getPluginAgentCanSupport(ChatParseContext chatParseContext) {
|
||||
public static List<ChatPlugin> getPluginAgentCanSupport(ParseContext parseContext) {
|
||||
PluginService pluginService = ContextUtils.getBean(PluginService.class);
|
||||
Agent agent = chatParseContext.getAgent();
|
||||
Agent agent = parseContext.getAgent();
|
||||
List<ChatPlugin> plugins = pluginService.getPluginList();
|
||||
if (Objects.isNull(agent)) {
|
||||
return plugins;
|
||||
@@ -191,9 +191,9 @@ public class PluginManager {
|
||||
return String.valueOf(Integer.parseInt(id) / 1000);
|
||||
}
|
||||
|
||||
public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ChatParseContext chatParseContext) {
|
||||
SchemaMapInfo schemaMapInfo = chatParseContext.getMapInfo();
|
||||
Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, chatParseContext);
|
||||
public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ParseContext parseContext) {
|
||||
SchemaMapInfo schemaMapInfo = parseContext.getMapInfo();
|
||||
Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, parseContext);
|
||||
if (CollectionUtils.isEmpty(pluginMatchedDataSet) && !plugin.isContainsAllDataSet()) {
|
||||
return Pair.of(false, Sets.newHashSet());
|
||||
}
|
||||
@@ -259,8 +259,8 @@ public class PluginManager {
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, ChatParseContext chatParseContext) {
|
||||
Set<Long> matchedDataSets = chatParseContext.getMapInfo().getMatchedDataSetInfos();
|
||||
private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, ParseContext parseContext) {
|
||||
Set<Long> matchedDataSets = parseContext.getMapInfo().getMatchedDataSetInfos();
|
||||
if (plugin.isContainsAllDataSet()) {
|
||||
return Sets.newHashSet(plugin.getDefaultMode());
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
@@ -28,22 +28,22 @@ import java.util.Set;
|
||||
*/
|
||||
public abstract class PluginRecognizer {
|
||||
|
||||
public void recognize(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
if (!checkPreCondition(chatParseContext)) {
|
||||
public void recognize(ParseContext parseContext, ParseResp parseResp) {
|
||||
if (!checkPreCondition(parseContext)) {
|
||||
return;
|
||||
}
|
||||
PluginRecallResult pluginRecallResult = recallPlugin(chatParseContext);
|
||||
PluginRecallResult pluginRecallResult = recallPlugin(parseContext);
|
||||
if (pluginRecallResult == null) {
|
||||
return;
|
||||
}
|
||||
buildQuery(chatParseContext, parseResp, pluginRecallResult);
|
||||
buildQuery(parseContext, parseResp, pluginRecallResult);
|
||||
}
|
||||
|
||||
public abstract boolean checkPreCondition(ChatParseContext chatParseContext);
|
||||
public abstract boolean checkPreCondition(ParseContext parseContext);
|
||||
|
||||
public abstract PluginRecallResult recallPlugin(ChatParseContext chatParseContext);
|
||||
public abstract PluginRecallResult recallPlugin(ParseContext parseContext);
|
||||
|
||||
public void buildQuery(ChatParseContext chatParseContext, ParseResp parseResp,
|
||||
public void buildQuery(ParseContext parseContext, ParseResp parseResp,
|
||||
PluginRecallResult pluginRecallResult) {
|
||||
ChatPlugin plugin = pluginRecallResult.getPlugin();
|
||||
Set<Long> dataSetIds = pluginRecallResult.getDataSetIds();
|
||||
@@ -52,21 +52,21 @@ public abstract class PluginRecognizer {
|
||||
}
|
||||
for (Long dataSetId : dataSetIds) {
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin,
|
||||
chatParseContext, pluginRecallResult.getDistance());
|
||||
parseContext, pluginRecallResult.getDistance());
|
||||
semanticParseInfo.setQueryMode(plugin.getType());
|
||||
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
||||
parseResp.getSelectedParses().add(semanticParseInfo);
|
||||
}
|
||||
}
|
||||
|
||||
protected List<ChatPlugin> getPluginList(ChatParseContext chatParseContext) {
|
||||
return PluginManager.getPluginAgentCanSupport(chatParseContext);
|
||||
protected List<ChatPlugin> getPluginList(ParseContext parseContext) {
|
||||
return PluginManager.getPluginAgentCanSupport(parseContext);
|
||||
}
|
||||
|
||||
protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin,
|
||||
ChatParseContext chatParseContext, double distance) {
|
||||
List<SchemaElementMatch> schemaElementMatches = chatParseContext.getMapInfo().getMatchedElements(dataSetId);
|
||||
QueryFilters queryFilters = chatParseContext.getQueryFilters();
|
||||
ParseContext parseContext, double distance) {
|
||||
List<SchemaElementMatch> schemaElementMatches = parseContext.getMapInfo().getMatchedElements(dataSetId);
|
||||
QueryFilters queryFilters = parseContext.getQueryFilters();
|
||||
if (schemaElementMatches == null) {
|
||||
schemaElementMatches = Lists.newArrayList();
|
||||
}
|
||||
@@ -80,7 +80,7 @@ public abstract class PluginRecognizer {
|
||||
pluginParseResult.setPlugin(plugin);
|
||||
pluginParseResult.setQueryFilters(queryFilters);
|
||||
pluginParseResult.setDistance(distance);
|
||||
pluginParseResult.setQueryText(chatParseContext.getQueryText());
|
||||
pluginParseResult.setQueryText(parseContext.getQueryText());
|
||||
properties.put(Constants.CONTEXT, pluginParseResult);
|
||||
properties.put("type", "plugin");
|
||||
properties.put("name", plugin.getName());
|
||||
|
||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
@@ -26,25 +26,25 @@ import java.util.stream.Collectors;
|
||||
@Slf4j
|
||||
public class EmbeddingRecallRecognizer extends PluginRecognizer {
|
||||
|
||||
public boolean checkPreCondition(ChatParseContext chatParseContext) {
|
||||
List<ChatPlugin> plugins = getPluginList(chatParseContext);
|
||||
public boolean checkPreCondition(ParseContext parseContext) {
|
||||
List<ChatPlugin> plugins = getPluginList(parseContext);
|
||||
return !CollectionUtils.isEmpty(plugins);
|
||||
}
|
||||
|
||||
public PluginRecallResult recallPlugin(ChatParseContext chatParseContext) {
|
||||
String text = chatParseContext.getQueryText();
|
||||
public PluginRecallResult recallPlugin(ParseContext parseContext) {
|
||||
String text = parseContext.getQueryText();
|
||||
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
|
||||
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
|
||||
return null;
|
||||
}
|
||||
List<ChatPlugin> plugins = getPluginList(chatParseContext);
|
||||
List<ChatPlugin> plugins = getPluginList(parseContext);
|
||||
Map<Long, ChatPlugin> pluginMap = plugins.stream().collect(Collectors.toMap(ChatPlugin::getId, p -> p));
|
||||
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
|
||||
ChatPlugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
|
||||
if (plugin == null) {
|
||||
continue;
|
||||
}
|
||||
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, chatParseContext);
|
||||
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, parseContext);
|
||||
log.info("embedding plugin resolve: {}", pair);
|
||||
if (pair.getLeft()) {
|
||||
Set<Long> dataSetList = pair.getRight();
|
||||
@@ -53,7 +53,7 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer {
|
||||
}
|
||||
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
||||
double distance = embeddingRetrieval.getDistance();
|
||||
double score = chatParseContext.getQueryText().length() * (1 - distance);
|
||||
double score = parseContext.getQueryText().length() * (1 - distance);
|
||||
return PluginRecallResult.builder()
|
||||
.plugin(plugin).dataSetIds(dataSetList).score(score).distance(distance).build();
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.tencent.supersonic.headless.chat;
|
||||
package com.tencent.supersonic.chat.server.pojo;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
@@ -6,7 +6,6 @@ import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ChatContext {
|
||||
|
||||
private Integer chatId;
|
||||
private String queryText;
|
||||
private SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||
@@ -1,17 +1,17 @@
|
||||
package com.tencent.supersonic.chat.server.pojo;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ChatExecuteContext {
|
||||
public class ExecuteContext {
|
||||
private User user;
|
||||
private Integer agentId;
|
||||
private Long queryId;
|
||||
private Integer chatId;
|
||||
private int parseId;
|
||||
private String queryText;
|
||||
private Agent agent;
|
||||
private Integer chatId;
|
||||
private Long queryId;
|
||||
private boolean saveAnswer;
|
||||
private SemanticParseInfo parseInfo;
|
||||
}
|
||||
@@ -7,14 +7,14 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ChatParseContext {
|
||||
private String queryText;
|
||||
private Integer chatId;
|
||||
private Agent agent;
|
||||
public class ParseContext {
|
||||
private User user;
|
||||
private String queryText;
|
||||
private Agent agent;
|
||||
private Integer chatId;
|
||||
private QueryFilters queryFilters;
|
||||
private boolean saveAnswer = true;
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
private SchemaMapInfo mapInfo;
|
||||
|
||||
public boolean enableNL2SQL() {
|
||||
if (agent == null) {
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.processor.execute;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
@@ -28,8 +28,8 @@ public class DimensionRecommendProcessor implements ExecuteResultProcessor {
|
||||
private static final int recommend_dimension_size = 5;
|
||||
|
||||
@Override
|
||||
public void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult) {
|
||||
SemanticParseInfo semanticParseInfo = chatExecuteContext.getParseInfo();
|
||||
public void process(ExecuteContext executeContext, QueryResult queryResult) {
|
||||
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
|
||||
if (!QueryType.METRIC.equals(semanticParseInfo.getQueryType())
|
||||
|| CollectionUtils.isEmpty(semanticParseInfo.getMetrics())) {
|
||||
return;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.processor.execute;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
|
||||
@@ -9,6 +9,6 @@ import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
*/
|
||||
public interface ExecuteResultProcessor extends ResultProcessor {
|
||||
|
||||
void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult);
|
||||
void process(ExecuteContext executeContext, QueryResult queryResult);
|
||||
|
||||
}
|
||||
@@ -11,7 +11,7 @@ import static com.tencent.supersonic.common.pojo.Constants.TIME_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.WEEK;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
@@ -60,15 +60,15 @@ import org.springframework.util.CollectionUtils;
|
||||
public class MetricRatioProcessor implements ExecuteResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult) {
|
||||
SemanticParseInfo semanticParseInfo = chatExecuteContext.getParseInfo();
|
||||
public void process(ExecuteContext executeContext, QueryResult queryResult) {
|
||||
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
|
||||
AggregatorConfig aggregatorConfig = ContextUtils.getBean(AggregatorConfig.class);
|
||||
if (CollectionUtils.isEmpty(semanticParseInfo.getMetrics())
|
||||
|| !aggregatorConfig.getEnableRatio()
|
||||
|| !QueryType.METRIC.equals(semanticParseInfo.getQueryType())) {
|
||||
return;
|
||||
}
|
||||
AggregateInfo aggregateInfo = getAggregateInfo(chatExecuteContext.getUser(),
|
||||
AggregateInfo aggregateInfo = getAggregateInfo(executeContext.getUser(),
|
||||
semanticParseInfo, queryResult);
|
||||
queryResult.setAggregateInfo(aggregateInfo);
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.processor.execute;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
@@ -34,8 +34,8 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
|
||||
private static final int METRIC_RECOMMEND_SIZE = 5;
|
||||
|
||||
@Override
|
||||
public void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult) {
|
||||
fillSimilarMetric(chatExecuteContext.getParseInfo());
|
||||
public void process(ExecuteContext executeContext, QueryResult queryResult) {
|
||||
fillSimilarMetric(executeContext.getParseInfo());
|
||||
}
|
||||
|
||||
private void fillSimilarMetric(SemanticParseInfo parseInfo) {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
|
||||
@@ -19,7 +19,7 @@ import java.util.List;
|
||||
public class EntityInfoProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
public void process(ParseContext parseContext, ParseResp parseResp) {
|
||||
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
|
||||
if (CollectionUtils.isEmpty(selectedParses)) {
|
||||
return;
|
||||
@@ -33,7 +33,7 @@ public class EntityInfoProcessor implements ParseResultProcessor {
|
||||
//1. set entity info
|
||||
SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class);
|
||||
DataSetSchema dataSetSchema = semanticService.getDataSetSchema(parseInfo.getDataSetId());
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, chatParseContext.getUser());
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, parseContext.getUser());
|
||||
if (QueryManager.isTagQuery(queryMode)
|
||||
|| QueryManager.isMetricQuery(queryMode)) {
|
||||
parseInfo.setEntityInfo(entityInfo);
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
|
||||
public interface ParseResultProcessor {
|
||||
|
||||
void process(ChatParseContext chatParseContext, ParseResp parseResp);
|
||||
void process(ParseContext parseContext, ParseResp parseResp);
|
||||
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
@@ -25,15 +25,15 @@ import java.util.stream.Collectors;
|
||||
public class QueryRecommendProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
CompletableFuture.runAsync(() -> doProcess(parseResp, chatParseContext));
|
||||
public void process(ParseContext parseContext, ParseResp parseResp) {
|
||||
CompletableFuture.runAsync(() -> doProcess(parseResp, parseContext));
|
||||
}
|
||||
|
||||
@SneakyThrows
|
||||
private void doProcess(ParseResp parseResp, ChatParseContext chatParseContext) {
|
||||
private void doProcess(ParseResp parseResp, ParseContext parseContext) {
|
||||
Long queryId = parseResp.getQueryId();
|
||||
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(chatParseContext.getQueryText(),
|
||||
chatParseContext.getAgent().getId());
|
||||
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(parseContext.getQueryText(),
|
||||
parseContext.getAgent().getId());
|
||||
ChatQueryDO chatQueryDO = getChatQuery(queryId);
|
||||
chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries));
|
||||
updateChatQuery(chatQueryDO);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@@ -12,7 +12,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||
public class TimeCostProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
public void process(ParseContext parseContext, ParseResp parseResp) {
|
||||
long parseStartTime = parseResp.getParseTimeCost().getParseStartTime();
|
||||
parseResp.getParseTimeCost().setParseTime(
|
||||
System.currentTimeMillis() - parseStartTime - parseResp.getParseTimeCost().getSqlTime());
|
||||
|
||||
@@ -6,11 +6,10 @@ import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
@@ -32,20 +31,20 @@ import javax.validation.Valid;
|
||||
public class ChatQueryController {
|
||||
|
||||
@Autowired
|
||||
private ChatService chatService;
|
||||
private ChatQueryService chatQueryService;
|
||||
|
||||
@PostMapping("search")
|
||||
public Object search(@RequestBody ChatParseReq chatParseReq, HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
chatParseReq.setUser(UserHolder.findUser(request, response));
|
||||
return chatService.search(chatParseReq);
|
||||
return chatQueryService.search(chatParseReq);
|
||||
}
|
||||
|
||||
@PostMapping("parse")
|
||||
public Object parse(@RequestBody ChatParseReq chatParseReq,
|
||||
HttpServletRequest request, HttpServletResponse response) throws Exception {
|
||||
chatParseReq.setUser(UserHolder.findUser(request, response));
|
||||
return chatService.performParsing(chatParseReq);
|
||||
return chatQueryService.performParsing(chatParseReq);
|
||||
}
|
||||
|
||||
@PostMapping("execute")
|
||||
@@ -53,7 +52,7 @@ public class ChatQueryController {
|
||||
HttpServletRequest request, HttpServletResponse response)
|
||||
throws Exception {
|
||||
chatExecuteReq.setUser(UserHolder.findUser(request, response));
|
||||
return chatService.performExecution(chatExecuteReq);
|
||||
return chatQueryService.performExecution(chatExecuteReq);
|
||||
}
|
||||
|
||||
@PostMapping("/")
|
||||
@@ -62,7 +61,7 @@ public class ChatQueryController {
|
||||
throws Exception {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
chatParseReq.setUser(user);
|
||||
ParseResp parseResp = chatService.performParsing(chatParseReq);
|
||||
ParseResp parseResp = chatQueryService.performParsing(chatParseReq);
|
||||
|
||||
if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) {
|
||||
throw new InvalidArgumentException("parser error,no selectedParses");
|
||||
@@ -72,27 +71,20 @@ public class ChatQueryController {
|
||||
BeanUtils.copyProperties(chatParseReq, chatExecuteReq);
|
||||
chatExecuteReq.setQueryId(parseResp.getQueryId());
|
||||
chatExecuteReq.setParseId(semanticParseInfo.getId());
|
||||
return chatService.performExecution(chatExecuteReq);
|
||||
}
|
||||
|
||||
@PostMapping("queryContext")
|
||||
public Object queryContext(@RequestBody QueryNLReq queryCtx,
|
||||
HttpServletRequest request, HttpServletResponse response) {
|
||||
queryCtx.setUser(UserHolder.findUser(request, response));
|
||||
return chatService.queryContext(queryCtx.getChatId());
|
||||
return chatQueryService.performExecution(chatExecuteReq);
|
||||
}
|
||||
|
||||
@PostMapping("queryData")
|
||||
public Object queryData(@RequestBody ChatQueryDataReq chatQueryDataReq,
|
||||
HttpServletRequest request, HttpServletResponse response) throws Exception {
|
||||
chatQueryDataReq.setUser(UserHolder.findUser(request, response));
|
||||
return chatService.queryData(chatQueryDataReq, UserHolder.findUser(request, response));
|
||||
return chatQueryService.queryData(chatQueryDataReq, UserHolder.findUser(request, response));
|
||||
}
|
||||
|
||||
@PostMapping("queryDimensionValue")
|
||||
public Object queryDimensionValue(@RequestBody @Valid DimensionValueReq dimensionValueReq,
|
||||
HttpServletRequest request, HttpServletResponse response) throws Exception {
|
||||
return chatService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response));
|
||||
return chatQueryService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
|
||||
public interface ChatContextService {
|
||||
|
||||
ChatContext getOrCreateContext(Integer chatId);
|
||||
|
||||
void updateContext(ChatContext chatCtx);
|
||||
|
||||
}
|
||||
@@ -4,7 +4,6 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
@@ -12,7 +11,7 @@ import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface ChatService {
|
||||
public interface ChatQueryService {
|
||||
|
||||
List<SearchResult> search(ChatParseReq chatParseReq);
|
||||
|
||||
@@ -24,8 +23,6 @@ public interface ChatService {
|
||||
|
||||
Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception;
|
||||
|
||||
SemanticParseInfo queryContext(Integer chatId);
|
||||
|
||||
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
|
||||
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
@@ -36,7 +36,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
private MemoryService memoryService;
|
||||
|
||||
@Autowired
|
||||
private ChatService chatService;
|
||||
private ChatQueryService chatQueryService;
|
||||
|
||||
private ExecutorService executorService = Executors.newFixedThreadPool(1);
|
||||
|
||||
@@ -103,7 +103,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
continue;
|
||||
}
|
||||
try {
|
||||
chatService.parseAndExecute(-1, agent.getId(), example);
|
||||
chatQueryService.parseAndExecute(-1, agent.getId(), example);
|
||||
} catch (Exception e) {
|
||||
log.warn("agent:{} example execute failed:{}", agent.getName(), example);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatContextRepository;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class ChatContextServiceImpl implements ChatContextService {
|
||||
|
||||
private ChatContextRepository chatContextRepository;
|
||||
|
||||
public ChatContextServiceImpl(ChatContextRepository chatContextRepository) {
|
||||
this.chatContextRepository = chatContextRepository;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatContext getOrCreateContext(Integer chatId) {
|
||||
return chatContextRepository.getOrCreateContext(chatId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateContext(ChatContext chatCtx) {
|
||||
log.debug("save ChatContext {}", chatCtx);
|
||||
chatContextRepository.updateContext(chatCtx);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -6,15 +6,16 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.executor.ChatExecutor;
|
||||
import com.tencent.supersonic.chat.server.parser.ChatParser;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor;
|
||||
import com.tencent.supersonic.chat.server.parser.ChatQueryParser;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
@@ -26,7 +27,7 @@ import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.RetrieveService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -39,49 +40,51 @@ import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class ChatServiceImpl implements ChatService {
|
||||
public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
|
||||
@Autowired
|
||||
private ChatManageService chatManageService;
|
||||
@Autowired
|
||||
private ChatQueryService chatQueryService;
|
||||
private ChatLayerService chatLayerService;
|
||||
@Autowired
|
||||
private RetrieveService retrieveService;
|
||||
@Autowired
|
||||
private AgentService agentService;
|
||||
@Autowired
|
||||
private SemanticLayerService semanticLayerService;
|
||||
@Autowired
|
||||
private ChatContextService chatContextService;
|
||||
|
||||
private List<ChatParser> chatParsers = ComponentFactory.getChatParsers();
|
||||
private List<ChatExecutor> chatExecutors = ComponentFactory.getChatExecutors();
|
||||
private List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
|
||||
private List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors();
|
||||
private List<ParseResultProcessor> parseResultProcessors = ComponentFactory.getParseProcessors();
|
||||
private List<ExecuteResultProcessor> executeResultProcessors = ComponentFactory.getExecuteProcessors();
|
||||
|
||||
@Override
|
||||
public List<SearchResult> search(ChatParseReq chatParseReq) {
|
||||
ChatParseContext chatParseContext = buildParseContext(chatParseReq);
|
||||
Agent agent = chatParseContext.getAgent();
|
||||
ParseContext parseContext = buildParseContext(chatParseReq);
|
||||
Agent agent = parseContext.getAgent();
|
||||
if (!agent.enableSearch()) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
||||
return retrieveService.retrieve(queryNLReq);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ParseResp performParsing(ChatParseReq chatParseReq) {
|
||||
ParseResp parseResp = new ParseResp(chatParseReq.getChatId(), chatParseReq.getQueryText());
|
||||
ParseResp parseResp = new ParseResp(chatParseReq.getQueryText());
|
||||
chatManageService.createChatQuery(chatParseReq, parseResp);
|
||||
ChatParseContext chatParseContext = buildParseContext(chatParseReq);
|
||||
supplyMapInfo(chatParseContext);
|
||||
for (ChatParser chatParser : chatParsers) {
|
||||
chatParser.parse(chatParseContext, parseResp);
|
||||
ParseContext parseContext = buildParseContext(chatParseReq);
|
||||
supplyMapInfo(parseContext);
|
||||
for (ChatQueryParser chatQueryParser : chatQueryParsers) {
|
||||
chatQueryParser.parse(parseContext, parseResp);
|
||||
}
|
||||
for (ParseResultProcessor processor : parseResultProcessors) {
|
||||
processor.process(chatParseContext, parseResp);
|
||||
processor.process(parseContext, parseResp);
|
||||
}
|
||||
chatParseReq.setQueryText(chatParseContext.getQueryText());
|
||||
parseResp.setQueryText(chatParseContext.getQueryText());
|
||||
chatParseReq.setQueryText(parseContext.getQueryText());
|
||||
parseResp.setQueryText(parseContext.getQueryText());
|
||||
chatManageService.batchAddParse(chatParseReq, parseResp);
|
||||
chatManageService.updateParseCostTime(parseResp);
|
||||
return parseResp;
|
||||
@@ -90,9 +93,9 @@ public class ChatServiceImpl implements ChatService {
|
||||
@Override
|
||||
public QueryResult performExecution(ChatExecuteReq chatExecuteReq) {
|
||||
QueryResult queryResult = new QueryResult();
|
||||
ChatExecuteContext chatExecuteContext = buildExecuteContext(chatExecuteReq);
|
||||
for (ChatExecutor chatExecutor : chatExecutors) {
|
||||
queryResult = chatExecutor.execute(chatExecuteContext);
|
||||
ExecuteContext executeContext = buildExecuteContext(chatExecuteReq);
|
||||
for (ChatQueryExecutor chatQueryExecutor : chatQueryExecutors) {
|
||||
queryResult = chatQueryExecutor.execute(executeContext);
|
||||
if (queryResult != null) {
|
||||
break;
|
||||
}
|
||||
@@ -100,7 +103,7 @@ public class ChatServiceImpl implements ChatService {
|
||||
|
||||
if (queryResult != null) {
|
||||
for (ExecuteResultProcessor processor : executeResultProcessors) {
|
||||
processor.process(chatExecuteContext, queryResult);
|
||||
processor.process(executeContext, queryResult);
|
||||
}
|
||||
saveQueryResult(chatExecuteReq, queryResult);
|
||||
}
|
||||
@@ -125,34 +128,36 @@ public class ChatServiceImpl implements ChatService {
|
||||
executeReq.setQueryId(parseResp.getQueryId());
|
||||
executeReq.setParseId(parseResp.getSelectedParses().get(0).getId());
|
||||
executeReq.setQueryText(queryText);
|
||||
executeReq.setChatId(parseResp.getChatId());
|
||||
executeReq.setChatId(chatId);
|
||||
executeReq.setUser(User.getFakeUser());
|
||||
executeReq.setAgentId(agentId);
|
||||
executeReq.setSaveAnswer(true);
|
||||
return performExecution(executeReq);
|
||||
}
|
||||
|
||||
private ChatParseContext buildParseContext(ChatParseReq chatParseReq) {
|
||||
ChatParseContext chatParseContext = new ChatParseContext();
|
||||
BeanMapper.mapper(chatParseReq, chatParseContext);
|
||||
private ParseContext buildParseContext(ChatParseReq chatParseReq) {
|
||||
ParseContext parseContext = new ParseContext();
|
||||
BeanMapper.mapper(chatParseReq, parseContext);
|
||||
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
|
||||
chatParseContext.setAgent(agent);
|
||||
return chatParseContext;
|
||||
parseContext.setAgent(agent);
|
||||
return parseContext;
|
||||
}
|
||||
|
||||
private void supplyMapInfo(ChatParseContext chatParseContext) {
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
MapResp mapResp = chatQueryService.performMapping(queryNLReq);
|
||||
chatParseContext.setMapInfo(mapResp.getMapInfo());
|
||||
private void supplyMapInfo(ParseContext parseContext) {
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
||||
MapResp mapResp = chatLayerService.performMapping(queryNLReq);
|
||||
parseContext.setMapInfo(mapResp.getMapInfo());
|
||||
}
|
||||
|
||||
private ChatExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
|
||||
ChatExecuteContext chatExecuteContext = new ChatExecuteContext();
|
||||
BeanMapper.mapper(chatExecuteReq, chatExecuteContext);
|
||||
private ExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
|
||||
ExecuteContext executeContext = new ExecuteContext();
|
||||
BeanMapper.mapper(chatExecuteReq, executeContext);
|
||||
SemanticParseInfo parseInfo = chatManageService.getParseInfo(
|
||||
chatExecuteReq.getQueryId(), chatExecuteReq.getParseId());
|
||||
chatExecuteContext.setParseInfo(parseInfo);
|
||||
return chatExecuteContext;
|
||||
Agent agent = agentService.getAgent(chatExecuteReq.getAgentId());
|
||||
executeContext.setAgent(agent);
|
||||
executeContext.setParseInfo(parseInfo);
|
||||
return executeContext;
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -163,12 +168,7 @@ public class ChatServiceImpl implements ChatService {
|
||||
QueryDataReq queryData = new QueryDataReq();
|
||||
BeanMapper.mapper(chatQueryDataReq, queryData);
|
||||
queryData.setParseInfo(parseInfo);
|
||||
return chatQueryService.executeDirectQuery(queryData, user);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SemanticParseInfo queryContext(Integer chatId) {
|
||||
return chatQueryService.queryContext(chatId);
|
||||
return chatLayerService.executeDirectQuery(queryData, user);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -176,7 +176,7 @@ public class ChatServiceImpl implements ChatService {
|
||||
Integer agentId = dimensionValueReq.getAgentId();
|
||||
Agent agent = agentService.getAgent(agentId);
|
||||
dimensionValueReq.setDataSetIds(agent.getDataSetIds());
|
||||
return chatQueryService.queryDimensionValue(dimensionValueReq, user);
|
||||
return chatLayerService.queryDimensionValue(dimensionValueReq, user);
|
||||
}
|
||||
|
||||
public void saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult) {
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
|
||||
import com.tencent.supersonic.chat.server.service.StatisticsService;
|
||||
import com.tencent.supersonic.headless.server.persistence.mapper.StatisticsMapper;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.StatisticsMapper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.util;
|
||||
|
||||
import com.tencent.supersonic.chat.server.executor.ChatExecutor;
|
||||
import com.tencent.supersonic.chat.server.parser.ChatParser;
|
||||
import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor;
|
||||
import com.tencent.supersonic.chat.server.parser.ChatQueryParser;
|
||||
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
|
||||
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
|
||||
@@ -16,8 +16,8 @@ import java.util.List;
|
||||
public class ComponentFactory {
|
||||
private static List<ParseResultProcessor> parseProcessors = new ArrayList<>();
|
||||
private static List<ExecuteResultProcessor> executeProcessors = new ArrayList<>();
|
||||
private static List<ChatParser> chatParsers = new ArrayList<>();
|
||||
private static List<ChatExecutor> chatExecutors = new ArrayList<>();
|
||||
private static List<ChatQueryParser> chatQueryParsers = new ArrayList<>();
|
||||
private static List<ChatQueryExecutor> chatQueryExecutors = new ArrayList<>();
|
||||
private static List<PluginRecognizer> pluginRecognizers = new ArrayList<>();
|
||||
|
||||
public static List<ParseResultProcessor> getParseProcessors() {
|
||||
@@ -30,14 +30,14 @@ public class ComponentFactory {
|
||||
? init(ExecuteResultProcessor.class, executeProcessors) : executeProcessors;
|
||||
}
|
||||
|
||||
public static List<ChatParser> getChatParsers() {
|
||||
return CollectionUtils.isEmpty(chatParsers)
|
||||
? init(ChatParser.class, chatParsers) : chatParsers;
|
||||
public static List<ChatQueryParser> getChatParsers() {
|
||||
return CollectionUtils.isEmpty(chatQueryParsers)
|
||||
? init(ChatQueryParser.class, chatQueryParsers) : chatQueryParsers;
|
||||
}
|
||||
|
||||
public static List<ChatExecutor> getChatExecutors() {
|
||||
return CollectionUtils.isEmpty(chatExecutors)
|
||||
? init(ChatExecutor.class, chatExecutors) : chatExecutors;
|
||||
public static List<ChatQueryExecutor> getChatExecutors() {
|
||||
return CollectionUtils.isEmpty(chatQueryExecutors)
|
||||
? init(ChatQueryExecutor.class, chatQueryExecutors) : chatQueryExecutors;
|
||||
}
|
||||
|
||||
public static List<PluginRecognizer> getPluginRecognizers() {
|
||||
|
||||
@@ -1,20 +1,25 @@
|
||||
package com.tencent.supersonic.chat.server.util;
|
||||
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import org.apache.commons.collections.MapUtils;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
public class QueryReqConverter {
|
||||
|
||||
public static QueryNLReq buildText2SqlQueryReq(ChatParseContext chatParseContext) {
|
||||
public static QueryNLReq buildText2SqlQueryReq(ParseContext parseContext) {
|
||||
return buildText2SqlQueryReq(parseContext, null);
|
||||
}
|
||||
|
||||
public static QueryNLReq buildText2SqlQueryReq(ParseContext parseContext, ChatContext chatCtx) {
|
||||
QueryNLReq queryNLReq = new QueryNLReq();
|
||||
BeanMapper.mapper(chatParseContext, queryNLReq);
|
||||
Agent agent = chatParseContext.getAgent();
|
||||
BeanMapper.mapper(parseContext, queryNLReq);
|
||||
Agent agent = parseContext.getAgent();
|
||||
if (agent == null) {
|
||||
return queryNLReq;
|
||||
}
|
||||
@@ -39,6 +44,9 @@ public class QueryReqConverter {
|
||||
}
|
||||
queryNLReq.setModelConfig(agent.getModelConfig());
|
||||
queryNLReq.setPromptConfig(agent.getPromptConfig());
|
||||
if (chatCtx != null) {
|
||||
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
|
||||
}
|
||||
return queryNLReq;
|
||||
}
|
||||
|
||||
|
||||
@@ -3,10 +3,10 @@
|
||||
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
|
||||
|
||||
|
||||
<mapper namespace="com.tencent.supersonic.headless.server.persistence.mapper.ChatContextMapper">
|
||||
<mapper namespace="com.tencent.supersonic.chat.server.persistence.mapper.ChatContextMapper">
|
||||
|
||||
<resultMap id="ChatContextDO"
|
||||
type="com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO">
|
||||
type="com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO">
|
||||
<id column="chat_id" property="chatId"/>
|
||||
<result column="modified_at" property="modifiedAt"/>
|
||||
<result column="user" property="user"/>
|
||||
@@ -20,7 +20,7 @@
|
||||
from s2_chat_context where chat_id=#{chatId} limit 1
|
||||
</select>
|
||||
|
||||
<insert id="addContext" parameterType="com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO" >
|
||||
<insert id="addContext" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO" >
|
||||
insert into s2_chat_context (chat_id,user,query_text,semantic_parse) values (#{chatId}, #{user},#{queryText}, #{semanticParse})
|
||||
</insert>
|
||||
<update id="updateContext">
|
||||
@@ -3,9 +3,9 @@
|
||||
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
|
||||
|
||||
|
||||
<mapper namespace="com.tencent.supersonic.headless.server.persistence.mapper.StatisticsMapper">
|
||||
<mapper namespace="com.tencent.supersonic.chat.server.persistence.mapper.StatisticsMapper">
|
||||
|
||||
<resultMap id="Statistics" type="com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO">
|
||||
<resultMap id="Statistics" type="com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO">
|
||||
<id column="question_id" property="questionId"/>
|
||||
<result column="chat_id" property="chatId"/>
|
||||
<result column="user_name" property="userName"/>
|
||||
@@ -16,7 +16,7 @@
|
||||
<result column="create_time" property="createTime"/>
|
||||
</resultMap>
|
||||
|
||||
<insert id="batchSaveStatistics" parameterType="com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO">
|
||||
<insert id="batchSaveStatistics" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO">
|
||||
insert into s2_chat_statistics
|
||||
(question_id,chat_id, user_name, query_text, interface_name,cost,type ,create_time)
|
||||
values
|
||||
@@ -11,7 +11,6 @@ import lombok.Data;
|
||||
public class ExecuteQueryReq {
|
||||
private User user;
|
||||
private Long queryId;
|
||||
private Integer chatId;
|
||||
private String queryText;
|
||||
private SemanticParseInfo parseInfo;
|
||||
private boolean saveAnswer;
|
||||
|
||||
@@ -9,6 +9,7 @@ import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import lombok.Data;
|
||||
|
||||
@@ -18,7 +19,6 @@ import java.util.Set;
|
||||
@Data
|
||||
public class QueryNLReq {
|
||||
private String queryText;
|
||||
private Integer chatId;
|
||||
private Set<Long> dataSetIds = Sets.newHashSet();
|
||||
private User user;
|
||||
private QueryFilters queryFilters;
|
||||
@@ -30,4 +30,5 @@ public class QueryNLReq {
|
||||
private ChatModelConfig modelConfig;
|
||||
private PromptConfig promptConfig;
|
||||
private List<SqlExemplar> dynamicExemplars = Lists.newArrayList();
|
||||
private SemanticParseInfo contextParseInfo;
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import java.util.stream.Collectors;
|
||||
|
||||
@Data
|
||||
public class ParseResp {
|
||||
private Integer chatId;
|
||||
private String queryText;
|
||||
private Long queryId;
|
||||
private ParseState state = ParseState.PENDING;
|
||||
@@ -24,8 +23,7 @@ public class ParseResp {
|
||||
FAILED
|
||||
}
|
||||
|
||||
public ParseResp(Integer chatId, String queryText) {
|
||||
this.chatId = chatId;
|
||||
public ParseResp(String queryText) {
|
||||
this.queryText = queryText;
|
||||
parseTimeCost.setParseStartTime(System.currentTimeMillis());
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
@@ -35,7 +36,6 @@ import java.util.stream.Collectors;
|
||||
public class ChatQueryContext {
|
||||
|
||||
private String queryText;
|
||||
private Integer chatId;
|
||||
private Set<Long> dataSetIds;
|
||||
private Map<Long, List<Long>> modelIdToDataSetIds;
|
||||
private User user;
|
||||
@@ -54,6 +54,7 @@ public class ChatQueryContext {
|
||||
private ChatModelConfig modelConfig;
|
||||
private PromptConfig promptConfig;
|
||||
private List<SqlExemplar> dynamicExemplars;
|
||||
private SemanticParseInfo contextParseInfo;
|
||||
|
||||
public List<SemanticQuery> getCandidateQueries() {
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
|
||||
@@ -11,7 +11,6 @@ import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
@@ -29,7 +28,7 @@ import java.util.stream.Collectors;
|
||||
public class QueryTypeParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||
public void parse(ChatQueryContext chatQueryContext) {
|
||||
|
||||
List<SemanticQuery> candidateQueries = chatQueryContext.getCandidateQueries();
|
||||
User user = chatQueryContext.getUser();
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.tencent.supersonic.headless.chat.parser;
|
||||
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
|
||||
/**
|
||||
@@ -10,5 +9,5 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
*/
|
||||
public interface SemanticParser {
|
||||
|
||||
void parse(ChatQueryContext chatQueryContext, ChatContext chatContext);
|
||||
void parse(ChatQueryContext chatQueryContext);
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
@@ -23,7 +22,7 @@ import org.apache.commons.collections.MapUtils;
|
||||
public class LLMSqlParser implements SemanticParser {
|
||||
|
||||
@Override
|
||||
public void parse(ChatQueryContext queryCtx, ChatContext chatCtx) {
|
||||
public void parse(ChatQueryContext queryCtx) {
|
||||
try {
|
||||
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
|
||||
//1.determine whether to skip this parser.
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.rule;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
@@ -41,7 +40,7 @@ public class AggregateTypeParser implements SemanticParser {
|
||||
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (k1, k2) -> k2));
|
||||
|
||||
@Override
|
||||
public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||
public void parse(ChatQueryContext chatQueryContext) {
|
||||
String queryText = chatQueryContext.getQueryText();
|
||||
AggregateConf aggregateConf = resolveAggregateConf(queryText);
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricModelQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricIdQuery;
|
||||
@@ -43,11 +42,11 @@ public class ContextInheritParser implements SemanticParser {
|
||||
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
|
||||
|
||||
@Override
|
||||
public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||
public void parse(ChatQueryContext chatQueryContext) {
|
||||
if (!shouldInherit(chatQueryContext)) {
|
||||
return;
|
||||
}
|
||||
Long dataSetId = getMatchedDataSet(chatQueryContext, chatContext);
|
||||
Long dataSetId = getMatchedDataSet(chatQueryContext);
|
||||
if (dataSetId == null) {
|
||||
return;
|
||||
}
|
||||
@@ -55,10 +54,11 @@ public class ContextInheritParser implements SemanticParser {
|
||||
List<SchemaElementMatch> elementMatches = chatQueryContext.getMapInfo().getMatchedElements(dataSetId);
|
||||
|
||||
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
|
||||
for (SchemaElementMatch match : chatContext.getParseInfo().getElementMatches()) {
|
||||
for (SchemaElementMatch match : chatQueryContext.getContextParseInfo().getElementMatches()) {
|
||||
SchemaElementType matchType = match.getElement().getType();
|
||||
// mutual exclusive element types should not be inherited
|
||||
RuleSemanticQuery ruleQuery = QueryManager.getRuleQuery(chatContext.getParseInfo().getQueryMode());
|
||||
RuleSemanticQuery ruleQuery = QueryManager.getRuleQuery(
|
||||
chatQueryContext.getContextParseInfo().getQueryMode());
|
||||
if (!containsTypes(elementMatches, matchType, ruleQuery)) {
|
||||
match.setInherited(true);
|
||||
matchesToInherit.add(match);
|
||||
@@ -68,7 +68,7 @@ public class ContextInheritParser implements SemanticParser {
|
||||
|
||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext);
|
||||
for (RuleSemanticQuery query : queries) {
|
||||
query.fillParseInfo(chatQueryContext, chatContext);
|
||||
query.fillParseInfo(chatQueryContext);
|
||||
if (existSameQuery(query.getParseInfo().getDataSetId(), query.getQueryMode(), chatQueryContext)) {
|
||||
continue;
|
||||
}
|
||||
@@ -108,8 +108,8 @@ public class ContextInheritParser implements SemanticParser {
|
||||
return metricModelQueries.size() == chatQueryContext.getCandidateQueries().size();
|
||||
}
|
||||
|
||||
protected Long getMatchedDataSet(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||
Long dataSetId = chatContext.getParseInfo().getDataSetId();
|
||||
protected Long getMatchedDataSet(ChatQueryContext chatQueryContext) {
|
||||
Long dataSetId = chatQueryContext.getContextParseInfo().getDataSetId();
|
||||
if (dataSetId == null) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
@@ -24,7 +23,7 @@ public class RuleSqlParser implements SemanticParser {
|
||||
);
|
||||
|
||||
@Override
|
||||
public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||
public void parse(ChatQueryContext chatQueryContext) {
|
||||
if (!chatQueryContext.getText2SQLType().enableRule()) {
|
||||
return;
|
||||
}
|
||||
@@ -34,11 +33,11 @@ public class RuleSqlParser implements SemanticParser {
|
||||
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(dataSetId);
|
||||
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext);
|
||||
for (RuleSemanticQuery query : queries) {
|
||||
query.fillParseInfo(chatQueryContext, chatContext);
|
||||
query.fillParseInfo(chatQueryContext);
|
||||
chatQueryContext.getCandidateQueries().add(query);
|
||||
}
|
||||
}
|
||||
|
||||
auxiliaryParsers.stream().forEach(p -> p.parse(chatQueryContext, chatContext));
|
||||
auxiliaryParsers.stream().forEach(p -> p.parse(chatQueryContext));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.chat.parser.rule;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
@@ -42,7 +41,7 @@ public class TimeRangeParser implements SemanticParser {
|
||||
private static final DateFormat DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd");
|
||||
|
||||
@Override
|
||||
public void parse(ChatQueryContext queryContext, ChatContext chatContext) {
|
||||
public void parse(ChatQueryContext queryContext) {
|
||||
String queryText = queryContext.getQueryText();
|
||||
DateConf dateConf = parseRecent(queryText);
|
||||
if (dateConf == null) {
|
||||
@@ -59,14 +58,14 @@ public class TimeRangeParser implements SemanticParser {
|
||||
query.getParseInfo().setScore(query.getParseInfo().getScore()
|
||||
+ dateConf.getDetectWord().length());
|
||||
}
|
||||
} else if (QueryManager.containsRuleQuery(chatContext.getParseInfo().getQueryMode())) {
|
||||
} else if (QueryManager.containsRuleQuery(queryContext.getContextParseInfo().getQueryMode())) {
|
||||
RuleSemanticQuery semanticQuery = QueryManager.createRuleQuery(
|
||||
chatContext.getParseInfo().getQueryMode());
|
||||
queryContext.getContextParseInfo().getQueryMode());
|
||||
// inherit parse info from context
|
||||
chatContext.getParseInfo().setDateInfo(dateConf);
|
||||
chatContext.getParseInfo().setScore(chatContext.getParseInfo().getScore()
|
||||
queryContext.getContextParseInfo().setDateInfo(dateConf);
|
||||
queryContext.getContextParseInfo().setScore(queryContext.getContextParseInfo().getScore()
|
||||
+ dateConf.getDetectWord().length());
|
||||
semanticQuery.setParseInfo(chatContext.getParseInfo());
|
||||
semanticQuery.setParseInfo(queryContext.getContextParseInfo());
|
||||
queryContext.getCandidateQueries().add(semanticQuery);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
|
||||
import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
@@ -49,13 +48,13 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
|
||||
initS2SqlByStruct(semanticSchema);
|
||||
}
|
||||
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext) {
|
||||
parseInfo.setQueryMode(getQueryMode());
|
||||
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
|
||||
|
||||
fillSchemaElement(parseInfo, semanticSchema);
|
||||
fillScore(parseInfo);
|
||||
fillDateConf(parseInfo, chatContext.getParseInfo());
|
||||
fillDateConf(parseInfo, chatQueryContext.getContextParseInfo());
|
||||
}
|
||||
|
||||
private void fillDateConf(SemanticParseInfo queryParseInfo, SemanticParseInfo chatParseInfo) {
|
||||
|
||||
@@ -7,7 +7,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
|
||||
@@ -19,8 +18,8 @@ import java.util.stream.Collectors;
|
||||
public abstract class DetailListQuery extends DetailSemanticQuery {
|
||||
|
||||
@Override
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||
super.fillParseInfo(chatQueryContext, chatContext);
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext) {
|
||||
super.fillParseInfo(chatQueryContext);
|
||||
this.addEntityDetailAndOrderByMetric(chatQueryContext, parseInfo);
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.time.LocalDate;
|
||||
@@ -35,8 +34,8 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||
super.fillParseInfo(chatQueryContext, chatContext);
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext) {
|
||||
super.fillParseInfo(chatQueryContext);
|
||||
|
||||
parseInfo.setQueryType(QueryType.DETAIL);
|
||||
parseInfo.setLimit(DETAIL_MAX_RESULTS);
|
||||
|
||||
@@ -8,7 +8,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.time.LocalDate;
|
||||
@@ -36,8 +35,8 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||
super.fillParseInfo(chatQueryContext, chatContext);
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext) {
|
||||
super.fillParseInfo(chatQueryContext);
|
||||
parseInfo.setLimit(METRIC_MAX_RESULTS);
|
||||
if (parseInfo.getDateInfo() == null) {
|
||||
DataSetSchema dataSetSchema =
|
||||
|
||||
@@ -6,7 +6,6 @@ import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@@ -50,8 +49,8 @@ public class MetricTopNQuery extends MetricSemanticQuery {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||
super.fillParseInfo(chatQueryContext, chatContext);
|
||||
public void fillParseInfo(ChatQueryContext chatQueryContext) {
|
||||
super.fillParseInfo(chatQueryContext);
|
||||
|
||||
parseInfo.setLimit(ORDERBY_MAX_RESULTS);
|
||||
parseInfo.setScore(parseInfo.getScore() + 2.0);
|
||||
|
||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.RetrieveService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -24,7 +24,7 @@ import javax.servlet.http.HttpServletResponse;
|
||||
public class ChatQueryApiController {
|
||||
|
||||
@Autowired
|
||||
private ChatQueryService chatQueryService;
|
||||
private ChatLayerService chatLayerService;
|
||||
|
||||
@Autowired
|
||||
private RetrieveService retrieveService;
|
||||
@@ -45,7 +45,7 @@ public class ChatQueryApiController {
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
queryNLReq.setUser(UserHolder.findUser(request, response));
|
||||
return chatQueryService.performMapping(queryNLReq);
|
||||
return chatLayerService.performMapping(queryNLReq);
|
||||
}
|
||||
|
||||
@PostMapping("/chat/parse")
|
||||
@@ -53,7 +53,7 @@ public class ChatQueryApiController {
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) throws Exception {
|
||||
queryNLReq.setUser(UserHolder.findUser(request, response));
|
||||
return chatQueryService.performParsing(queryNLReq);
|
||||
return chatLayerService.performParsing(queryNLReq);
|
||||
}
|
||||
|
||||
@PostMapping("/chat")
|
||||
@@ -61,7 +61,7 @@ public class ChatQueryApiController {
|
||||
HttpServletRequest request,
|
||||
HttpServletResponse response) throws Exception {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
ParseResp parseResp = chatQueryService.performParsing(queryNLReq);
|
||||
ParseResp parseResp = chatLayerService.performParsing(queryNLReq);
|
||||
if (parseResp.getState().equals(ParseResp.ParseState.COMPLETED)) {
|
||||
SemanticParseInfo parseInfo = parseResp.getSelectedParses().get(0);
|
||||
QuerySqlReq sqlReq = new QuerySqlReq();
|
||||
|
||||
@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.server.facade.rest;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
@@ -20,14 +20,14 @@ import javax.servlet.http.HttpServletResponse;
|
||||
public class MetaDiscoveryApiController {
|
||||
|
||||
@Autowired
|
||||
private ChatQueryService chatQueryService;
|
||||
private ChatLayerService chatLayerService;
|
||||
|
||||
@PostMapping("map")
|
||||
public Object map(@RequestBody QueryMapReq queryMapReq,
|
||||
HttpServletRequest request, HttpServletResponse response) throws Exception {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
queryMapReq.setUser(user);
|
||||
return chatQueryService.map(queryMapReq);
|
||||
return chatLayerService.map(queryMapReq);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlsReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
@@ -33,7 +33,7 @@ public class SqlQueryApiController {
|
||||
private SemanticLayerService semanticLayerService;
|
||||
|
||||
@Autowired
|
||||
private ChatQueryService chatQueryService;
|
||||
private ChatLayerService chatLayerService;
|
||||
|
||||
@PostMapping("/sql")
|
||||
public Object queryBySql(@RequestBody QuerySqlReq querySqlReq,
|
||||
@@ -42,7 +42,7 @@ public class SqlQueryApiController {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
String sql = querySqlReq.getSql();
|
||||
querySqlReq.setSql(StringUtil.replaceBackticks(sql));
|
||||
chatQueryService.correct(querySqlReq, user);
|
||||
chatLayerService.correct(querySqlReq, user);
|
||||
return semanticLayerService.queryByReq(querySqlReq, user);
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ public class SqlQueryApiController {
|
||||
QuerySqlReq querySqlReq = new QuerySqlReq();
|
||||
BeanUtils.copyProperties(querySqlsReq, querySqlReq);
|
||||
querySqlReq.setSql(StringUtil.replaceBackticks(sql));
|
||||
chatQueryService.correct(querySqlReq, user);
|
||||
chatLayerService.correct(querySqlReq, user);
|
||||
return querySqlReq;
|
||||
}).collect(Collectors.toList());
|
||||
|
||||
@@ -82,7 +82,7 @@ public class SqlQueryApiController {
|
||||
QuerySqlReq querySqlReq = new QuerySqlReq();
|
||||
BeanUtils.copyProperties(querySqlsReq, querySqlReq);
|
||||
querySqlReq.setSql(StringUtil.replaceBackticks(sql));
|
||||
chatQueryService.correct(querySqlReq, user);
|
||||
chatLayerService.correct(querySqlReq, user);
|
||||
return querySqlReq;
|
||||
}).collect(Collectors.toList());
|
||||
List<SemanticQueryResp> semanticQueryRespList = new ArrayList<>();
|
||||
@@ -104,7 +104,7 @@ public class SqlQueryApiController {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
String sql = querySqlReq.getSql();
|
||||
querySqlReq.setSql(StringUtil.replaceBackticks(sql));
|
||||
return chatQueryService.validate(querySqlReq, user);
|
||||
return chatLayerService.validate(querySqlReq, user);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.headless.server.facade.service;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlEvaluation;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
|
||||
@@ -16,14 +15,12 @@ import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
/***dd
|
||||
* SemanticLayerService for query and search
|
||||
*/
|
||||
public interface ChatQueryService {
|
||||
public interface ChatLayerService {
|
||||
|
||||
MapResp performMapping(QueryNLReq queryNLReq);
|
||||
|
||||
ParseResp performParsing(QueryNLReq queryNLReq);
|
||||
|
||||
SemanticParseInfo queryContext(Integer chatId);
|
||||
|
||||
QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws Exception;
|
||||
|
||||
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
|
||||
@@ -22,7 +22,6 @@ import com.tencent.supersonic.headless.chat.mapper.MatchText;
|
||||
import com.tencent.supersonic.headless.chat.mapper.ModelWithSemanticType;
|
||||
import com.tencent.supersonic.headless.chat.mapper.SearchMatchStrategy;
|
||||
import com.tencent.supersonic.headless.server.facade.service.RetrieveService;
|
||||
import com.tencent.supersonic.headless.server.web.service.ChatContextService;
|
||||
import com.tencent.supersonic.headless.server.web.service.DataSetService;
|
||||
import com.tencent.supersonic.headless.server.web.service.SchemaService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -52,9 +51,6 @@ public class RetrieveServiceImpl implements RetrieveService {
|
||||
@Autowired
|
||||
private DataSetService dataSetService;
|
||||
|
||||
@Autowired
|
||||
private ChatContextService chatContextService;
|
||||
|
||||
@Autowired
|
||||
private SchemaService schemaService;
|
||||
|
||||
@@ -135,7 +131,7 @@ public class RetrieveServiceImpl implements RetrieveService {
|
||||
|
||||
List<Long> possibleDataSets = NatureHelper.selectPossibleDataSets(originals);
|
||||
|
||||
Long contextDataset = chatContextService.getContextDataset(queryCtx.getChatId());
|
||||
Long contextDataset = queryCtx.getContextParseInfo().getDataSetId();
|
||||
|
||||
log.debug("possibleDataSets:{},dataSetInfoStat:{},contextDataset:{}",
|
||||
possibleDataSets, dataSetInfoStat, contextDataset);
|
||||
|
||||
@@ -44,7 +44,6 @@ import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.corrector.GrammarCorrector;
|
||||
import com.tencent.supersonic.headless.chat.corrector.SchemaCorrector;
|
||||
@@ -57,12 +56,11 @@ import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import com.tencent.supersonic.headless.server.utils.ChatWorkflowEngine;
|
||||
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
|
||||
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
|
||||
import com.tencent.supersonic.headless.server.web.service.ChatContextService;
|
||||
import com.tencent.supersonic.headless.server.web.service.DataSetService;
|
||||
import com.tencent.supersonic.headless.server.web.service.SchemaService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -98,14 +96,12 @@ import java.util.stream.Collectors;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
public class S2ChatLayerService implements ChatLayerService {
|
||||
@Autowired
|
||||
private SemanticLayerService semanticLayerService;
|
||||
@Autowired
|
||||
private SchemaService schemaService;
|
||||
@Autowired
|
||||
private ChatContextService chatContextService;
|
||||
@Autowired
|
||||
private KnowledgeBaseService knowledgeBaseService;
|
||||
@Autowired
|
||||
private DataSetService dataSetService;
|
||||
@@ -141,14 +137,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
|
||||
@Override
|
||||
public ParseResp performParsing(QueryNLReq queryNLReq) {
|
||||
ParseResp parseResult = new ParseResp(queryNLReq.getChatId(), queryNLReq.getQueryText());
|
||||
// build queryContext and chatContext
|
||||
ParseResp parseResult = new ParseResp(queryNLReq.getQueryText());
|
||||
// build queryContext
|
||||
ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq);
|
||||
|
||||
// in order to support multi-turn conversation, chat context is needed
|
||||
ChatContext chatCtx = chatContextService.getOrCreateContext(queryNLReq.getChatId());
|
||||
|
||||
chatWorkflowEngine.execute(queryCtx, chatCtx, parseResult);
|
||||
chatWorkflowEngine.execute(queryCtx, parseResult);
|
||||
|
||||
List<SemanticParseInfo> parseInfos = queryCtx.getCandidateQueries().stream()
|
||||
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
|
||||
@@ -173,12 +166,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
return queryCtx;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SemanticParseInfo queryContext(Integer chatId) {
|
||||
ChatContext context = chatContextService.getOrCreateContext(chatId);
|
||||
return context.getParseInfo();
|
||||
}
|
||||
|
||||
//mainly used for executing after revising filters,for example:"fans_cnt>=100000"->"fans_cnt>500000",
|
||||
//"style='流行'"->"style in ['流行','爱国']"
|
||||
@Override
|
||||
@@ -14,7 +14,6 @@ import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
|
||||
import com.tencent.supersonic.headless.server.web.service.SchemaService;
|
||||
@@ -40,7 +39,7 @@ import java.util.stream.Collectors;
|
||||
public class ParseInfoProcessor implements ResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ParseResp parseResp, ChatQueryContext chatQueryContext, ChatContext chatContext) {
|
||||
public void process(ParseResp parseResp, ChatQueryContext chatQueryContext) {
|
||||
List<SemanticQuery> candidateQueries = chatQueryContext.getCandidateQueries();
|
||||
if (CollectionUtils.isEmpty(candidateQueries)) {
|
||||
return;
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package com.tencent.supersonic.headless.server.processor;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
|
||||
/**
|
||||
@@ -9,6 +8,6 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
*/
|
||||
public interface ResultProcessor {
|
||||
|
||||
void process(ParseResp parseResp, ChatQueryContext chatQueryContext, ChatContext chatContext);
|
||||
void process(ParseResp parseResp, ChatQueryContext chatQueryContext);
|
||||
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.ChatQueryContext;
|
||||
import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector;
|
||||
import com.tencent.supersonic.headless.chat.mapper.SchemaMapper;
|
||||
@@ -39,7 +38,7 @@ public class ChatWorkflowEngine {
|
||||
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSemanticCorrectors();
|
||||
private List<ResultProcessor> resultProcessors = ComponentFactory.getResultProcessors();
|
||||
|
||||
public void execute(ChatQueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) {
|
||||
public void execute(ChatQueryContext queryCtx, ParseResp parseResult) {
|
||||
queryCtx.setChatWorkflowState(ChatWorkflowState.MAPPING);
|
||||
while (queryCtx.getChatWorkflowState() != ChatWorkflowState.FINISHED) {
|
||||
switch (queryCtx.getChatWorkflowState()) {
|
||||
@@ -54,7 +53,7 @@ public class ChatWorkflowEngine {
|
||||
}
|
||||
break;
|
||||
case PARSING:
|
||||
performParsing(queryCtx, chatCtx);
|
||||
performParsing(queryCtx);
|
||||
if (queryCtx.getCandidateQueries().size() == 0) {
|
||||
parseResult.setState(ParseResp.ParseState.FAILED);
|
||||
parseResult.setErrorMsg("No semantic queries can be parsed out.");
|
||||
@@ -75,7 +74,7 @@ public class ChatWorkflowEngine {
|
||||
break;
|
||||
case PROCESSING:
|
||||
default:
|
||||
performProcessing(queryCtx, chatCtx, parseResult);
|
||||
performProcessing(queryCtx, parseResult);
|
||||
if (parseResult.getState().equals(ParseResp.ParseState.PENDING)) {
|
||||
parseResult.setState(ParseResp.ParseState.COMPLETED);
|
||||
}
|
||||
@@ -92,9 +91,9 @@ public class ChatWorkflowEngine {
|
||||
}
|
||||
}
|
||||
|
||||
private void performParsing(ChatQueryContext queryCtx, ChatContext chatCtx) {
|
||||
private void performParsing(ChatQueryContext queryCtx) {
|
||||
semanticParsers.forEach(parser -> {
|
||||
parser.parse(queryCtx, chatCtx);
|
||||
parser.parse(queryCtx);
|
||||
log.debug("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
|
||||
});
|
||||
}
|
||||
@@ -116,9 +115,9 @@ public class ChatWorkflowEngine {
|
||||
}
|
||||
}
|
||||
|
||||
private void performProcessing(ChatQueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) {
|
||||
private void performProcessing(ChatQueryContext queryCtx, ParseResp parseResult) {
|
||||
resultProcessors.forEach(processor -> {
|
||||
processor.process(parseResult, queryCtx, chatCtx);
|
||||
processor.process(parseResult, queryCtx);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
package com.tencent.supersonic.headless.server.web.service;
|
||||
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
|
||||
public interface ChatContextService {
|
||||
|
||||
/***
|
||||
* get the model from context
|
||||
* @param chatId
|
||||
* @return
|
||||
*/
|
||||
Long getContextDataset(Integer chatId);
|
||||
|
||||
ChatContext getOrCreateContext(Integer chatId);
|
||||
|
||||
void updateContext(ChatContext chatCtx);
|
||||
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
package com.tencent.supersonic.headless.server.web.service.impl;
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.headless.server.persistence.repository.ChatContextRepository;
|
||||
import com.tencent.supersonic.headless.server.web.service.ChatContextService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class ChatContextServiceImpl implements ChatContextService {
|
||||
|
||||
private ChatContextRepository chatContextRepository;
|
||||
|
||||
public ChatContextServiceImpl(ChatContextRepository chatContextRepository) {
|
||||
this.chatContextRepository = chatContextRepository;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Long getContextDataset(Integer chatId) {
|
||||
if (Objects.isNull(chatId)) {
|
||||
return null;
|
||||
}
|
||||
ChatContext chatContext = getOrCreateContext(chatId);
|
||||
if (Objects.isNull(chatContext)) {
|
||||
return null;
|
||||
}
|
||||
SemanticParseInfo originalSemanticParse = chatContext.getParseInfo();
|
||||
if (Objects.nonNull(originalSemanticParse) && Objects.nonNull(originalSemanticParse.getDataSetId())) {
|
||||
return originalSemanticParse.getDataSetId();
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatContext getOrCreateContext(Integer chatId) {
|
||||
return chatContextRepository.getOrCreateContext(chatId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateContext(ChatContext chatCtx) {
|
||||
log.debug("save ChatContext {}", chatCtx);
|
||||
chatContextRepository.updateContext(chatCtx);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -44,7 +44,7 @@ import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.TagItem;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.CollectDO;
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.MetricDO;
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.MetricQueryDefaultConfigDO;
|
||||
@@ -111,7 +111,7 @@ public class MetricServiceImpl extends ServiceImpl<MetricDOMapper, MetricDO>
|
||||
|
||||
private TagMetaService tagMetaService;
|
||||
|
||||
private ChatQueryService chatQueryService;
|
||||
private ChatLayerService chatLayerService;
|
||||
|
||||
public MetricServiceImpl(MetricRepository metricRepository,
|
||||
ModelService modelService,
|
||||
@@ -121,7 +121,7 @@ public class MetricServiceImpl extends ServiceImpl<MetricDOMapper, MetricDO>
|
||||
ApplicationEventPublisher eventPublisher,
|
||||
DimensionService dimensionService,
|
||||
TagMetaService tagMetaService,
|
||||
@Lazy ChatQueryService chatQueryService) {
|
||||
@Lazy ChatLayerService chatLayerService) {
|
||||
this.metricRepository = metricRepository;
|
||||
this.modelService = modelService;
|
||||
this.aliasGenerateHelper = aliasGenerateHelper;
|
||||
@@ -130,7 +130,7 @@ public class MetricServiceImpl extends ServiceImpl<MetricDOMapper, MetricDO>
|
||||
this.dataSetService = dataSetService;
|
||||
this.dimensionService = dimensionService;
|
||||
this.tagMetaService = tagMetaService;
|
||||
this.chatQueryService = chatQueryService;
|
||||
this.chatLayerService = chatLayerService;
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -299,7 +299,7 @@ public class MetricServiceImpl extends ServiceImpl<MetricDOMapper, MetricDO>
|
||||
queryMapReq.setQueryText(pageMetricReq.getKey());
|
||||
queryMapReq.setUser(user);
|
||||
queryMapReq.setMapModeEnum(MapModeEnum.LOOSE);
|
||||
MapInfoResp mapMeta = chatQueryService.map(queryMapReq);
|
||||
MapInfoResp mapMeta = chatLayerService.map(queryMapReq);
|
||||
Map<String, DataSetMapInfo> dataSetMapInfoMap = mapMeta.getDataSetMapInfo();
|
||||
if (CollectionUtils.isEmpty(dataSetMapInfoMap)) {
|
||||
return metricRespPageInfo;
|
||||
|
||||
@@ -16,7 +16,7 @@ import com.tencent.supersonic.headless.api.pojo.enums.MetricType;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.MetricReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.MetricDO;
|
||||
import com.tencent.supersonic.headless.server.persistence.repository.MetricRepository;
|
||||
import com.tencent.supersonic.headless.server.utils.AliasGenerateHelper;
|
||||
@@ -77,10 +77,10 @@ public class MetricServiceImplTest {
|
||||
DataSetService dataSetService = Mockito.mock(DataSetServiceImpl.class);
|
||||
DimensionService dimensionService = Mockito.mock(DimensionService.class);
|
||||
TagMetaService tagMetaService = Mockito.mock(TagMetaService.class);
|
||||
ChatQueryService chatQueryService = Mockito.mock(ChatQueryService.class);
|
||||
ChatLayerService chatLayerService = Mockito.mock(ChatLayerService.class);
|
||||
return new MetricServiceImpl(metricRepository, modelService, aliasGenerateHelper,
|
||||
collectService, dataSetService, eventPublisher, dimensionService,
|
||||
tagMetaService, chatQueryService);
|
||||
tagMetaService, chatLayerService);
|
||||
}
|
||||
|
||||
private MetricReq buildMetricReq() {
|
||||
|
||||
@@ -5,7 +5,7 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authorization.service.AuthService;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.chat.server.service.PluginService;
|
||||
import com.tencent.supersonic.common.service.SystemConfigService;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
|
||||
@@ -75,7 +75,7 @@ public abstract class S2BaseDemo implements CommandLineRunner {
|
||||
@Autowired
|
||||
protected TagObjectService tagObjectService;
|
||||
@Autowired
|
||||
protected ChatService chatService;
|
||||
protected ChatQueryService chatQueryService;
|
||||
@Autowired
|
||||
protected ChatManageService chatManageService;
|
||||
@Autowired
|
||||
|
||||
@@ -137,11 +137,11 @@ public class S2VisitsDemo extends S2BaseDemo {
|
||||
public void addSampleChats(Integer agentId) {
|
||||
Long chatId = chatManageService.addChat(user, "样例对话1", agentId);
|
||||
|
||||
chatService.parseAndExecute(chatId.intValue(), agentId, "超音数 访问次数");
|
||||
chatService.parseAndExecute(chatId.intValue(), agentId, "按部门统计");
|
||||
chatService.parseAndExecute(chatId.intValue(), agentId, "查询近30天");
|
||||
chatService.parseAndExecute(chatId.intValue(), agentId, "alice 停留时长");
|
||||
chatService.parseAndExecute(chatId.intValue(), agentId, "访问次数最高的部门");
|
||||
chatQueryService.parseAndExecute(chatId.intValue(), agentId, "超音数 访问次数");
|
||||
chatQueryService.parseAndExecute(chatId.intValue(), agentId, "按部门统计");
|
||||
chatQueryService.parseAndExecute(chatId.intValue(), agentId, "查询近30天");
|
||||
chatQueryService.parseAndExecute(chatId.intValue(), agentId, "alice 停留时长");
|
||||
chatQueryService.parseAndExecute(chatId.intValue(), agentId, "访问次数最高的部门");
|
||||
}
|
||||
|
||||
private Integer addAgent(long dataSetId) {
|
||||
|
||||
@@ -51,12 +51,12 @@ com.tencent.supersonic.headless.server.processor.ResultProcessor=\
|
||||
|
||||
### chat-server SPIs
|
||||
|
||||
com.tencent.supersonic.chat.server.parser.ChatParser=\
|
||||
com.tencent.supersonic.chat.server.parser.ChatQueryParser=\
|
||||
com.tencent.supersonic.chat.server.parser.NL2PluginParser, \
|
||||
com.tencent.supersonic.chat.server.parser.NL2SQLParser,\
|
||||
com.tencent.supersonic.chat.server.parser.PlainTextParser
|
||||
|
||||
com.tencent.supersonic.chat.server.executor.ChatExecutor=\
|
||||
com.tencent.supersonic.chat.server.executor.ChatQueryExecutor=\
|
||||
com.tencent.supersonic.chat.server.executor.PluginExecutor, \
|
||||
com.tencent.supersonic.chat.server.executor.SqlExecutor,\
|
||||
com.tencent.supersonic.chat.server.executor.PlainTextExecutor
|
||||
|
||||
@@ -4,7 +4,7 @@ import com.tencent.supersonic.BaseApplication;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
@@ -28,7 +28,7 @@ public class BaseTest extends BaseApplication {
|
||||
protected final String period = "DAY";
|
||||
|
||||
@Autowired
|
||||
protected ChatService chatService;
|
||||
protected ChatQueryService chatQueryService;
|
||||
@Autowired
|
||||
protected AgentService agentService;
|
||||
|
||||
@@ -37,33 +37,34 @@ public class BaseTest extends BaseApplication {
|
||||
|
||||
SemanticParseInfo semanticParseInfo = parseResp.getSelectedParses().get(0);
|
||||
ChatExecuteReq request = ChatExecuteReq.builder()
|
||||
.chatId(parseResp.getChatId())
|
||||
.queryText(parseResp.getQueryText())
|
||||
.user(DataUtils.getUser())
|
||||
.parseId(semanticParseInfo.getId())
|
||||
.queryId(parseResp.getQueryId())
|
||||
.chatId(chatId)
|
||||
.saveAnswer(true)
|
||||
.build();
|
||||
QueryResult queryResult = chatService.performExecution(request);
|
||||
QueryResult queryResult = chatQueryService.performExecution(request);
|
||||
queryResult.setChatContext(semanticParseInfo);
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
protected QueryResult submitNewChat(String queryText, Integer agentId) throws Exception {
|
||||
ParseResp parseResp = submitParse(queryText, agentId);
|
||||
int chatId = 10;
|
||||
ParseResp parseResp = submitParse(queryText, agentId, chatId);
|
||||
|
||||
SemanticParseInfo parseInfo = parseResp.getSelectedParses().get(0);
|
||||
ChatExecuteReq request = ChatExecuteReq.builder()
|
||||
.chatId(parseResp.getChatId())
|
||||
.queryText(parseResp.getQueryText())
|
||||
.user(DataUtils.getUser())
|
||||
.parseId(parseInfo.getId())
|
||||
.agentId(agentId)
|
||||
.chatId(chatId)
|
||||
.queryId(parseResp.getQueryId())
|
||||
.saveAnswer(false)
|
||||
.build();
|
||||
|
||||
QueryResult result = chatService.performExecution(request);
|
||||
QueryResult result = chatQueryService.performExecution(request);
|
||||
result.setChatContext(parseInfo);
|
||||
return result;
|
||||
}
|
||||
@@ -74,7 +75,7 @@ public class BaseTest extends BaseApplication {
|
||||
}
|
||||
ChatParseReq chatParseReq = DataUtils.getChatParseReq(chatId, queryText);
|
||||
chatParseReq.setAgentId(agentId);
|
||||
return chatService.performParsing(chatParseReq);
|
||||
return chatQueryService.performParsing(chatParseReq);
|
||||
}
|
||||
|
||||
protected ParseResp submitParse(String queryText, Integer agentId) {
|
||||
@@ -83,7 +84,7 @@ public class BaseTest extends BaseApplication {
|
||||
|
||||
protected ParseResp submitParseWithAgent(String queryText, Integer agentId) {
|
||||
ChatParseReq chatParseReq = DataUtils.getChatParseReqWithAgent(10, queryText, agentId);
|
||||
return chatService.performParsing(chatParseReq);
|
||||
return chatQueryService.performParsing(chatParseReq);
|
||||
}
|
||||
|
||||
protected void assertSchemaElements(Set<SchemaElement> expected, Set<SchemaElement> actual) {
|
||||
|
||||
@@ -4,7 +4,7 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||
import org.junit.Assert;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
@@ -15,7 +15,7 @@ import java.util.Collections;
|
||||
public class MetaDiscoveryTest extends BaseTest {
|
||||
|
||||
@Autowired
|
||||
protected ChatQueryService chatQueryService;
|
||||
protected ChatLayerService chatLayerService;
|
||||
|
||||
@Test
|
||||
public void testGetMapMeta() throws Exception {
|
||||
@@ -24,7 +24,7 @@ public class MetaDiscoveryTest extends BaseTest {
|
||||
queryMapReq.setTopN(10);
|
||||
queryMapReq.setUser(User.getFakeUser());
|
||||
queryMapReq.setDataSetNames(Collections.singletonList("超音数数据集"));
|
||||
MapInfoResp mapMeta = chatQueryService.map(queryMapReq);
|
||||
MapInfoResp mapMeta = chatLayerService.map(queryMapReq);
|
||||
|
||||
Assertions.assertNotNull(mapMeta);
|
||||
Assertions.assertNotEquals(0, mapMeta.getDataSetMapInfo().get("超音数数据集").getMapFields());
|
||||
@@ -39,7 +39,7 @@ public class MetaDiscoveryTest extends BaseTest {
|
||||
queryMapReq.setUser(User.getFakeUser());
|
||||
queryMapReq.setDataSetNames(Collections.singletonList("艺人库"));
|
||||
queryMapReq.setQueryDataType(QueryDataType.TAG);
|
||||
MapInfoResp mapMeta = chatQueryService.map(queryMapReq);
|
||||
MapInfoResp mapMeta = chatLayerService.map(queryMapReq);
|
||||
Assert.assertNotNull(mapMeta);
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ public class MetaDiscoveryTest extends BaseTest {
|
||||
queryMapReq.setUser(User.getFakeUser());
|
||||
queryMapReq.setDataSetNames(Collections.singletonList("超音数"));
|
||||
queryMapReq.setQueryDataType(QueryDataType.METRIC);
|
||||
MapInfoResp mapMeta = chatQueryService.map(queryMapReq);
|
||||
MapInfoResp mapMeta = chatLayerService.map(queryMapReq);
|
||||
Assert.assertNotNull(mapMeta);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user