mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-15 06:27:21 +00:00
(improvement)(headless&chat)Move ChatContext from Headless module to Chat module.
This commit is contained in:
@@ -1,10 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
|
||||
public interface ChatExecutor {
|
||||
|
||||
QueryResult execute(ChatExecuteContext chatExecuteContext);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
|
||||
public interface ChatQueryExecutor {
|
||||
|
||||
QueryResult execute(ExecuteContext executeContext);
|
||||
|
||||
}
|
||||
@@ -4,7 +4,7 @@ import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.parser.ParserConfig;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
@@ -23,7 +23,7 @@ import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
||||
|
||||
public class PlainTextExecutor implements ChatExecutor {
|
||||
public class PlainTextExecutor implements ChatQueryExecutor {
|
||||
|
||||
private static final String INSTRUCTION = ""
|
||||
+ "#Role: You are a nice person to talk to.\n"
|
||||
@@ -34,34 +34,34 @@ public class PlainTextExecutor implements ChatExecutor {
|
||||
+ "#Your response: ";
|
||||
|
||||
@Override
|
||||
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
|
||||
if (!"PLAIN_TEXT".equals(chatExecuteContext.getParseInfo().getQueryMode())) {
|
||||
public QueryResult execute(ExecuteContext executeContext) {
|
||||
if (!"PLAIN_TEXT".equals(executeContext.getParseInfo().getQueryMode())) {
|
||||
return null;
|
||||
}
|
||||
|
||||
String promptStr = String.format(INSTRUCTION, getHistoryInputs(chatExecuteContext),
|
||||
chatExecuteContext.getQueryText());
|
||||
String promptStr = String.format(INSTRUCTION, getHistoryInputs(executeContext),
|
||||
executeContext.getQueryText());
|
||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
|
||||
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatAgent.getModelConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
|
||||
QueryResult result = new QueryResult();
|
||||
result.setQueryState(QueryState.SUCCESS);
|
||||
result.setQueryMode(chatExecuteContext.getParseInfo().getQueryMode());
|
||||
result.setQueryMode(executeContext.getParseInfo().getQueryMode());
|
||||
result.setTextResult(response.content().text());
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private String getHistoryInputs(ChatExecuteContext chatExecuteContext) {
|
||||
private String getHistoryInputs(ExecuteContext executeContext) {
|
||||
StringBuilder historyInput = new StringBuilder();
|
||||
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
|
||||
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
|
||||
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
MultiTurnConfig agentMultiTurnConfig = chatAgent.getMultiTurnConfig();
|
||||
@@ -70,7 +70,7 @@ public class PlainTextExecutor implements ChatExecutor {
|
||||
? agentMultiTurnConfig.isEnableMultiTurn() : globalMultiTurnConfig;
|
||||
|
||||
if (Boolean.TRUE.equals(multiTurnConfig)) {
|
||||
List<ParseResp> parseResps = getHistoryParseResult(chatExecuteContext.getChatId(), 5);
|
||||
List<ParseResp> parseResps = getHistoryParseResult(executeContext.getChatId(), 5);
|
||||
parseResps.stream().forEach(p -> {
|
||||
historyInput.append(p.getQueryText());
|
||||
historyInput.append(";");
|
||||
|
||||
@@ -2,15 +2,15 @@ package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
|
||||
public class PluginExecutor implements ChatExecutor {
|
||||
public class PluginExecutor implements ChatQueryExecutor {
|
||||
|
||||
@Override
|
||||
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
|
||||
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo();
|
||||
public QueryResult execute(ExecuteContext executeContext) {
|
||||
SemanticParseInfo parseInfo = executeContext.getParseInfo();
|
||||
if (!PluginQueryManager.isPluginQuery(parseInfo.getQueryMode())) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.server.executor;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.chat.server.util.ResultFormatter;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
@@ -12,10 +12,10 @@ import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
|
||||
import com.tencent.supersonic.headless.chat.ChatContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import com.tencent.supersonic.headless.server.web.service.ChatContextService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||
import lombok.SneakyThrows;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
|
||||
@@ -25,12 +25,12 @@ import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
public class SqlExecutor implements ChatExecutor {
|
||||
public class SqlExecutor implements ChatQueryExecutor {
|
||||
|
||||
@SneakyThrows
|
||||
@Override
|
||||
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
|
||||
QueryResult queryResult = doExecute(chatExecuteContext);
|
||||
public QueryResult execute(ExecuteContext executeContext) {
|
||||
QueryResult queryResult = doExecute(executeContext);
|
||||
|
||||
if (queryResult != null) {
|
||||
String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(),
|
||||
@@ -41,13 +41,13 @@ public class SqlExecutor implements ChatExecutor {
|
||||
&& queryResult.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
|
||||
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
|
||||
memoryService.createMemory(ChatMemoryDO.builder()
|
||||
.agentId(chatExecuteContext.getAgentId())
|
||||
.agentId(executeContext.getAgent().getId())
|
||||
.status(MemoryStatus.PENDING)
|
||||
.question(chatExecuteContext.getQueryText())
|
||||
.s2sql(chatExecuteContext.getParseInfo().getSqlInfo().getParsedS2SQL())
|
||||
.dbSchema(buildSchemaStr(chatExecuteContext.getParseInfo()))
|
||||
.createdBy(chatExecuteContext.getUser().getName())
|
||||
.updatedBy(chatExecuteContext.getUser().getName())
|
||||
.question(executeContext.getQueryText())
|
||||
.s2sql(executeContext.getParseInfo().getSqlInfo().getParsedS2SQL())
|
||||
.dbSchema(buildSchemaStr(executeContext.getParseInfo()))
|
||||
.createdBy(executeContext.getUser().getName())
|
||||
.updatedBy(executeContext.getUser().getName())
|
||||
.createdAt(new Date())
|
||||
.build());
|
||||
}
|
||||
@@ -57,12 +57,12 @@ public class SqlExecutor implements ChatExecutor {
|
||||
}
|
||||
|
||||
@SneakyThrows
|
||||
private QueryResult doExecute(ChatExecuteContext chatExecuteContext) {
|
||||
private QueryResult doExecute(ExecuteContext executeContext) {
|
||||
SemanticLayerService semanticLayer = ContextUtils.getBean(SemanticLayerService.class);
|
||||
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
||||
|
||||
ChatContext chatCtx = chatContextService.getOrCreateContext(chatExecuteContext.getChatId());
|
||||
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo();
|
||||
ChatContext chatCtx = chatContextService.getOrCreateContext(executeContext.getChatId());
|
||||
SemanticParseInfo parseInfo = executeContext.getParseInfo();
|
||||
if (Objects.isNull(parseInfo.getSqlInfo())
|
||||
|| StringUtils.isBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) {
|
||||
return null;
|
||||
@@ -74,8 +74,10 @@ public class SqlExecutor implements ChatExecutor {
|
||||
sqlReq.setSqlInfo(parseInfo.getSqlInfo());
|
||||
sqlReq.setDataSetId(parseInfo.getDataSetId());
|
||||
long startTime = System.currentTimeMillis();
|
||||
SemanticQueryResp queryResp = semanticLayer.queryByReq(sqlReq, chatExecuteContext.getUser());
|
||||
SemanticQueryResp queryResp = semanticLayer.queryByReq(sqlReq, executeContext.getUser());
|
||||
QueryResult queryResult = new QueryResult();
|
||||
queryResult.setChatContext(parseInfo);
|
||||
queryResult.setQueryMode(parseInfo.getQueryMode());
|
||||
if (queryResp != null) {
|
||||
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
|
||||
List<Map<String, Object>> resultList = queryResp == null ? new ArrayList<>()
|
||||
@@ -85,7 +87,6 @@ public class SqlExecutor implements ChatExecutor {
|
||||
queryResult.setQuerySql(queryResp.getSql());
|
||||
queryResult.setQueryResults(resultList);
|
||||
queryResult.setQueryColumns(columns);
|
||||
queryResult.setQueryMode(parseInfo.getQueryMode());
|
||||
queryResult.setQueryState(QueryState.SUCCESS);
|
||||
|
||||
chatCtx.setParseInfo(parseInfo);
|
||||
@@ -94,7 +95,6 @@ public class SqlExecutor implements ChatExecutor {
|
||||
queryResult.setQueryState(QueryState.INVALID);
|
||||
queryResult.setQueryMode(parseInfo.getQueryMode());
|
||||
}
|
||||
queryResult.setChatContext(chatCtx.getParseInfo());
|
||||
return queryResult;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
|
||||
public interface ChatParser {
|
||||
|
||||
void parse(ChatParseContext chatParseContext, ParseResp parseResp);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
|
||||
public interface ChatQueryParser {
|
||||
|
||||
void parse(ParseContext parseContext, ParseResp parseResp);
|
||||
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
@@ -9,18 +9,18 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
public class NL2PluginParser implements ChatParser {
|
||||
public class NL2PluginParser implements ChatQueryParser {
|
||||
|
||||
private final List<PluginRecognizer> pluginRecognizers = ComponentFactory.getPluginRecognizers();
|
||||
|
||||
@Override
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
if (!chatParseContext.getAgent().containsPluginTool()) {
|
||||
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
||||
if (!parseContext.getAgent().containsPluginTool()) {
|
||||
return;
|
||||
}
|
||||
|
||||
pluginRecognizers.forEach(pluginRecognizer -> {
|
||||
pluginRecognizer.recognize(chatParseContext, parseResp);
|
||||
pluginRecognizer.recognize(parseContext, parseResp);
|
||||
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),
|
||||
JsonUtil.toString(parseResp));
|
||||
});
|
||||
|
||||
@@ -3,7 +3,8 @@ package com.tencent.supersonic.chat.server.parser;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
@@ -17,7 +18,8 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
@@ -42,7 +44,7 @@ import java.util.stream.Collectors;
|
||||
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
||||
|
||||
@Slf4j
|
||||
public class NL2SQLParser implements ChatParser {
|
||||
public class NL2SQLParser implements ChatQueryParser {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
@@ -73,27 +75,30 @@ public class NL2SQLParser implements ChatParser {
|
||||
+ "#Response: ";
|
||||
|
||||
@Override
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) {
|
||||
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
||||
if (!parseContext.enableNL2SQL() || checkSkip(parseResp)) {
|
||||
return;
|
||||
}
|
||||
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
||||
ChatContext chatCtx = chatContextService.getOrCreateContext(parseContext.getChatId());
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
|
||||
chatParseContext.getAgent().getModelConfig());
|
||||
parseContext.getAgent().getModelConfig());
|
||||
|
||||
processMultiTurn(chatLanguageModel, chatParseContext);
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
addDynamicExemplars(chatParseContext.getAgent().getId(), queryNLReq);
|
||||
processMultiTurn(chatLanguageModel, parseContext);
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx);
|
||||
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
|
||||
|
||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryNLReq);
|
||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||
ParseResp text2SqlParseResp = chatLayerService.performParsing(queryNLReq);
|
||||
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
|
||||
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
|
||||
} else {
|
||||
parseResp.setErrorMsg(rewriteErrorMessage(chatLanguageModel,
|
||||
chatParseContext.getQueryText(),
|
||||
parseContext.getQueryText(),
|
||||
text2SqlParseResp.getErrorMsg(),
|
||||
queryNLReq.getDynamicExemplars(),
|
||||
chatParseContext.getAgent().getExamples()));
|
||||
parseContext.getAgent().getExamples()));
|
||||
}
|
||||
parseResp.setState(text2SqlParseResp.getState());
|
||||
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
|
||||
@@ -155,9 +160,9 @@ public class NL2SQLParser implements ChatParser {
|
||||
parseInfo.setTextInfo(textBuilder.toString());
|
||||
}
|
||||
|
||||
private void processMultiTurn(ChatLanguageModel chatLanguageModel, ChatParseContext chatParseContext) {
|
||||
private void processMultiTurn(ChatLanguageModel chatLanguageModel, ParseContext parseContext) {
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig();
|
||||
MultiTurnConfig agentMultiTurnConfig = parseContext.getAgent().getMultiTurnConfig();
|
||||
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
|
||||
|
||||
Boolean multiTurnConfig = agentMultiTurnConfig != null
|
||||
@@ -167,11 +172,11 @@ public class NL2SQLParser implements ChatParser {
|
||||
}
|
||||
|
||||
// derive mapping result of current question and parsing result of last question.
|
||||
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
MapResp currentMapResult = chatQueryService.performMapping(queryNLReq);
|
||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
||||
MapResp currentMapResult = chatLayerService.performMapping(queryNLReq);
|
||||
|
||||
List<ParseResp> historyParseResults = getHistoryParseResult(chatParseContext.getChatId(), 1);
|
||||
List<ParseResp> historyParseResults = getHistoryParseResult(parseContext.getChatId(), 1);
|
||||
if (historyParseResults.size() == 0) {
|
||||
return;
|
||||
}
|
||||
@@ -196,7 +201,7 @@ public class NL2SQLParser implements ChatParser {
|
||||
String rewrittenQuery = response.content().text();
|
||||
keyPipelineLog.info("NL2SQLParser modelResp:{}", rewrittenQuery);
|
||||
|
||||
chatParseContext.setQueryText(rewrittenQuery);
|
||||
parseContext.setQueryText(rewrittenQuery);
|
||||
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
|
||||
lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Service("ChatParserConfig")
|
||||
@Service("ChatQueryParserConfig")
|
||||
@Slf4j
|
||||
public class ParserConfig extends ParameterConfig {
|
||||
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
|
||||
|
||||
public class PlainTextParser implements ChatParser {
|
||||
public class PlainTextParser implements ChatQueryParser {
|
||||
|
||||
@Override
|
||||
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
if (chatParseContext.getAgent().containsAnyTool()) {
|
||||
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
||||
if (parseContext.getAgent().containsAnyTool()) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.time.Instant;
|
||||
|
||||
@Data
|
||||
public class ChatContextDO implements Serializable {
|
||||
|
||||
private Integer chatId;
|
||||
private Instant modifiedAt;
|
||||
private String user;
|
||||
private String queryText;
|
||||
private String semanticParse;
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Getter;
|
||||
|
||||
import java.util.Date;
|
||||
|
||||
@Data
|
||||
@Builder
|
||||
@NoArgsConstructor
|
||||
@Getter
|
||||
@AllArgsConstructor
|
||||
public class StatisticsDO {
|
||||
/**
|
||||
* questionId
|
||||
*/
|
||||
private Long questionId;
|
||||
|
||||
/**
|
||||
* chatId
|
||||
*/
|
||||
private Long chatId;
|
||||
|
||||
/**
|
||||
* createTime
|
||||
*/
|
||||
private Date createTime;
|
||||
|
||||
/**
|
||||
* queryText
|
||||
*/
|
||||
private String queryText;
|
||||
|
||||
/**
|
||||
* userName
|
||||
*/
|
||||
private String userName;
|
||||
|
||||
|
||||
/**
|
||||
* interface
|
||||
*/
|
||||
private String interfaceName;
|
||||
|
||||
/**
|
||||
* cost
|
||||
*/
|
||||
private Integer cost;
|
||||
|
||||
private Integer type;
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.mapper;
|
||||
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
|
||||
@Mapper
|
||||
public interface ChatContextMapper {
|
||||
|
||||
ChatContextDO getContextByChatId(Integer chatId);
|
||||
|
||||
int updateContext(ChatContextDO contextDO);
|
||||
|
||||
int addContext(ChatContextDO contextDO);
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.mapper;
|
||||
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
|
||||
import org.apache.ibatis.annotations.Mapper;
|
||||
import org.apache.ibatis.annotations.Param;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
@Mapper
|
||||
public interface StatisticsMapper {
|
||||
boolean batchSaveStatistics(@Param("list") List<StatisticsDO> list);
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.repository;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
|
||||
public interface ChatContextRepository {
|
||||
|
||||
ChatContext getOrCreateContext(Integer chatId);
|
||||
|
||||
void updateContext(ChatContext chatCtx);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
package com.tencent.supersonic.chat.server.persistence.repository.impl;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatContextRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.ChatContextMapper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.context.annotation.Primary;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
@Repository
|
||||
@Primary
|
||||
@Slf4j
|
||||
public class ChatContextRepositoryImpl implements ChatContextRepository {
|
||||
|
||||
|
||||
private final ChatContextMapper chatContextMapper;
|
||||
|
||||
public ChatContextRepositoryImpl(ChatContextMapper chatContextMapper) {
|
||||
this.chatContextMapper = chatContextMapper;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatContext getOrCreateContext(Integer chatId) {
|
||||
ChatContextDO context = chatContextMapper.getContextByChatId(chatId);
|
||||
if (context == null) {
|
||||
ChatContext chatContext = new ChatContext();
|
||||
chatContext.setChatId(chatId);
|
||||
return chatContext;
|
||||
}
|
||||
return cast(context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateContext(ChatContext chatCtx) {
|
||||
ChatContextDO context = cast(chatCtx);
|
||||
if (chatContextMapper.getContextByChatId(chatCtx.getChatId()) == null) {
|
||||
chatContextMapper.addContext(context);
|
||||
} else {
|
||||
chatContextMapper.updateContext(context);
|
||||
}
|
||||
}
|
||||
|
||||
private ChatContext cast(ChatContextDO contextDO) {
|
||||
ChatContext chatContext = new ChatContext();
|
||||
chatContext.setChatId(contextDO.getChatId());
|
||||
chatContext.setUser(contextDO.getUser());
|
||||
chatContext.setQueryText(contextDO.getQueryText());
|
||||
if (contextDO.getSemanticParse() != null && !contextDO.getSemanticParse().isEmpty()) {
|
||||
SemanticParseInfo semanticParseInfo = JsonUtil.toObject(contextDO.getSemanticParse(),
|
||||
SemanticParseInfo.class);
|
||||
chatContext.setParseInfo(semanticParseInfo);
|
||||
}
|
||||
return chatContext;
|
||||
}
|
||||
|
||||
private ChatContextDO cast(ChatContext chatContext) {
|
||||
ChatContextDO chatContextDO = new ChatContextDO();
|
||||
chatContextDO.setChatId(chatContext.getChatId());
|
||||
chatContextDO.setQueryText(chatContext.getQueryText());
|
||||
chatContextDO.setUser(chatContext.getUser());
|
||||
if (chatContext.getParseInfo() != null) {
|
||||
Gson g = new Gson();
|
||||
chatContextDO.setSemanticParse(g.toJson(chatContext.getParseInfo()));
|
||||
}
|
||||
return chatContextDO;
|
||||
}
|
||||
}
|
||||
@@ -183,7 +183,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
|
||||
public List<ParseResp> getContextualParseInfo(Integer chatId) {
|
||||
List<ChatParseDO> chatParseDOList = chatParseMapper.getContextualParseInfo(chatId);
|
||||
List<ParseResp> semanticParseInfoList = chatParseDOList.stream().map(parseInfo -> {
|
||||
ParseResp parseResp = new ParseResp(chatId, parseInfo.getQueryText());
|
||||
ParseResp parseResp = new ParseResp(parseInfo.getQueryText());
|
||||
List<SemanticParseInfo> selectedParses = new ArrayList<>();
|
||||
selectedParses.add(JSONObject.parseObject(parseInfo.getParseInfo(), SemanticParseInfo.class));
|
||||
parseResp.setSelectedParses(selectedParses);
|
||||
|
||||
@@ -11,7 +11,7 @@ import com.tencent.supersonic.chat.server.plugin.build.WebBase;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginAddEvent;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.service.PluginService;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||
@@ -52,9 +52,9 @@ public class PluginManager {
|
||||
@Autowired
|
||||
private EmbeddingService embeddingService;
|
||||
|
||||
public static List<ChatPlugin> getPluginAgentCanSupport(ChatParseContext chatParseContext) {
|
||||
public static List<ChatPlugin> getPluginAgentCanSupport(ParseContext parseContext) {
|
||||
PluginService pluginService = ContextUtils.getBean(PluginService.class);
|
||||
Agent agent = chatParseContext.getAgent();
|
||||
Agent agent = parseContext.getAgent();
|
||||
List<ChatPlugin> plugins = pluginService.getPluginList();
|
||||
if (Objects.isNull(agent)) {
|
||||
return plugins;
|
||||
@@ -191,9 +191,9 @@ public class PluginManager {
|
||||
return String.valueOf(Integer.parseInt(id) / 1000);
|
||||
}
|
||||
|
||||
public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ChatParseContext chatParseContext) {
|
||||
SchemaMapInfo schemaMapInfo = chatParseContext.getMapInfo();
|
||||
Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, chatParseContext);
|
||||
public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ParseContext parseContext) {
|
||||
SchemaMapInfo schemaMapInfo = parseContext.getMapInfo();
|
||||
Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, parseContext);
|
||||
if (CollectionUtils.isEmpty(pluginMatchedDataSet) && !plugin.isContainsAllDataSet()) {
|
||||
return Pair.of(false, Sets.newHashSet());
|
||||
}
|
||||
@@ -259,8 +259,8 @@ public class PluginManager {
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, ChatParseContext chatParseContext) {
|
||||
Set<Long> matchedDataSets = chatParseContext.getMapInfo().getMatchedDataSetInfos();
|
||||
private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, ParseContext parseContext) {
|
||||
Set<Long> matchedDataSets = parseContext.getMapInfo().getMatchedDataSetInfos();
|
||||
if (plugin.isContainsAllDataSet()) {
|
||||
return Sets.newHashSet(plugin.getDefaultMode());
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
@@ -28,22 +28,22 @@ import java.util.Set;
|
||||
*/
|
||||
public abstract class PluginRecognizer {
|
||||
|
||||
public void recognize(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
if (!checkPreCondition(chatParseContext)) {
|
||||
public void recognize(ParseContext parseContext, ParseResp parseResp) {
|
||||
if (!checkPreCondition(parseContext)) {
|
||||
return;
|
||||
}
|
||||
PluginRecallResult pluginRecallResult = recallPlugin(chatParseContext);
|
||||
PluginRecallResult pluginRecallResult = recallPlugin(parseContext);
|
||||
if (pluginRecallResult == null) {
|
||||
return;
|
||||
}
|
||||
buildQuery(chatParseContext, parseResp, pluginRecallResult);
|
||||
buildQuery(parseContext, parseResp, pluginRecallResult);
|
||||
}
|
||||
|
||||
public abstract boolean checkPreCondition(ChatParseContext chatParseContext);
|
||||
public abstract boolean checkPreCondition(ParseContext parseContext);
|
||||
|
||||
public abstract PluginRecallResult recallPlugin(ChatParseContext chatParseContext);
|
||||
public abstract PluginRecallResult recallPlugin(ParseContext parseContext);
|
||||
|
||||
public void buildQuery(ChatParseContext chatParseContext, ParseResp parseResp,
|
||||
public void buildQuery(ParseContext parseContext, ParseResp parseResp,
|
||||
PluginRecallResult pluginRecallResult) {
|
||||
ChatPlugin plugin = pluginRecallResult.getPlugin();
|
||||
Set<Long> dataSetIds = pluginRecallResult.getDataSetIds();
|
||||
@@ -52,21 +52,21 @@ public abstract class PluginRecognizer {
|
||||
}
|
||||
for (Long dataSetId : dataSetIds) {
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin,
|
||||
chatParseContext, pluginRecallResult.getDistance());
|
||||
parseContext, pluginRecallResult.getDistance());
|
||||
semanticParseInfo.setQueryMode(plugin.getType());
|
||||
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
||||
parseResp.getSelectedParses().add(semanticParseInfo);
|
||||
}
|
||||
}
|
||||
|
||||
protected List<ChatPlugin> getPluginList(ChatParseContext chatParseContext) {
|
||||
return PluginManager.getPluginAgentCanSupport(chatParseContext);
|
||||
protected List<ChatPlugin> getPluginList(ParseContext parseContext) {
|
||||
return PluginManager.getPluginAgentCanSupport(parseContext);
|
||||
}
|
||||
|
||||
protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin,
|
||||
ChatParseContext chatParseContext, double distance) {
|
||||
List<SchemaElementMatch> schemaElementMatches = chatParseContext.getMapInfo().getMatchedElements(dataSetId);
|
||||
QueryFilters queryFilters = chatParseContext.getQueryFilters();
|
||||
ParseContext parseContext, double distance) {
|
||||
List<SchemaElementMatch> schemaElementMatches = parseContext.getMapInfo().getMatchedElements(dataSetId);
|
||||
QueryFilters queryFilters = parseContext.getQueryFilters();
|
||||
if (schemaElementMatches == null) {
|
||||
schemaElementMatches = Lists.newArrayList();
|
||||
}
|
||||
@@ -80,7 +80,7 @@ public abstract class PluginRecognizer {
|
||||
pluginParseResult.setPlugin(plugin);
|
||||
pluginParseResult.setQueryFilters(queryFilters);
|
||||
pluginParseResult.setDistance(distance);
|
||||
pluginParseResult.setQueryText(chatParseContext.getQueryText());
|
||||
pluginParseResult.setQueryText(parseContext.getQueryText());
|
||||
properties.put(Constants.CONTEXT, pluginParseResult);
|
||||
properties.put("type", "plugin");
|
||||
properties.put("name", plugin.getName());
|
||||
|
||||
@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import dev.langchain4j.store.embedding.RetrieveQueryResult;
|
||||
@@ -26,25 +26,25 @@ import java.util.stream.Collectors;
|
||||
@Slf4j
|
||||
public class EmbeddingRecallRecognizer extends PluginRecognizer {
|
||||
|
||||
public boolean checkPreCondition(ChatParseContext chatParseContext) {
|
||||
List<ChatPlugin> plugins = getPluginList(chatParseContext);
|
||||
public boolean checkPreCondition(ParseContext parseContext) {
|
||||
List<ChatPlugin> plugins = getPluginList(parseContext);
|
||||
return !CollectionUtils.isEmpty(plugins);
|
||||
}
|
||||
|
||||
public PluginRecallResult recallPlugin(ChatParseContext chatParseContext) {
|
||||
String text = chatParseContext.getQueryText();
|
||||
public PluginRecallResult recallPlugin(ParseContext parseContext) {
|
||||
String text = parseContext.getQueryText();
|
||||
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
|
||||
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
|
||||
return null;
|
||||
}
|
||||
List<ChatPlugin> plugins = getPluginList(chatParseContext);
|
||||
List<ChatPlugin> plugins = getPluginList(parseContext);
|
||||
Map<Long, ChatPlugin> pluginMap = plugins.stream().collect(Collectors.toMap(ChatPlugin::getId, p -> p));
|
||||
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
|
||||
ChatPlugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
|
||||
if (plugin == null) {
|
||||
continue;
|
||||
}
|
||||
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, chatParseContext);
|
||||
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, parseContext);
|
||||
log.info("embedding plugin resolve: {}", pair);
|
||||
if (pair.getLeft()) {
|
||||
Set<Long> dataSetList = pair.getRight();
|
||||
@@ -53,7 +53,7 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer {
|
||||
}
|
||||
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
|
||||
double distance = embeddingRetrieval.getDistance();
|
||||
double score = chatParseContext.getQueryText().length() * (1 - distance);
|
||||
double score = parseContext.getQueryText().length() * (1 - distance);
|
||||
return PluginRecallResult.builder()
|
||||
.plugin(plugin).dataSetIds(dataSetList).score(score).distance(distance).build();
|
||||
}
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.tencent.supersonic.chat.server.pojo;
|
||||
|
||||
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ChatContext {
|
||||
private Integer chatId;
|
||||
private String queryText;
|
||||
private SemanticParseInfo parseInfo = new SemanticParseInfo();
|
||||
private String user;
|
||||
}
|
||||
@@ -1,17 +1,17 @@
|
||||
package com.tencent.supersonic.chat.server.pojo;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ChatExecuteContext {
|
||||
public class ExecuteContext {
|
||||
private User user;
|
||||
private Integer agentId;
|
||||
private Long queryId;
|
||||
private Integer chatId;
|
||||
private int parseId;
|
||||
private String queryText;
|
||||
private Agent agent;
|
||||
private Integer chatId;
|
||||
private Long queryId;
|
||||
private boolean saveAnswer;
|
||||
private SemanticParseInfo parseInfo;
|
||||
}
|
||||
@@ -7,14 +7,14 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ChatParseContext {
|
||||
private String queryText;
|
||||
private Integer chatId;
|
||||
private Agent agent;
|
||||
public class ParseContext {
|
||||
private User user;
|
||||
private String queryText;
|
||||
private Agent agent;
|
||||
private Integer chatId;
|
||||
private QueryFilters queryFilters;
|
||||
private boolean saveAnswer = true;
|
||||
private SchemaMapInfo mapInfo = new SchemaMapInfo();
|
||||
private SchemaMapInfo mapInfo;
|
||||
|
||||
public boolean enableNL2SQL() {
|
||||
if (agent == null) {
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.processor.execute;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
@@ -28,8 +28,8 @@ public class DimensionRecommendProcessor implements ExecuteResultProcessor {
|
||||
private static final int recommend_dimension_size = 5;
|
||||
|
||||
@Override
|
||||
public void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult) {
|
||||
SemanticParseInfo semanticParseInfo = chatExecuteContext.getParseInfo();
|
||||
public void process(ExecuteContext executeContext, QueryResult queryResult) {
|
||||
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
|
||||
if (!QueryType.METRIC.equals(semanticParseInfo.getQueryType())
|
||||
|| CollectionUtils.isEmpty(semanticParseInfo.getMetrics())) {
|
||||
return;
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.processor.execute;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
|
||||
@@ -9,6 +9,6 @@ import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
*/
|
||||
public interface ExecuteResultProcessor extends ResultProcessor {
|
||||
|
||||
void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult);
|
||||
void process(ExecuteContext executeContext, QueryResult queryResult);
|
||||
|
||||
}
|
||||
@@ -11,7 +11,7 @@ import static com.tencent.supersonic.common.pojo.Constants.TIME_FORMAT;
|
||||
import static com.tencent.supersonic.common.pojo.Constants.WEEK;
|
||||
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.common.pojo.DateConf;
|
||||
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
|
||||
import com.tencent.supersonic.common.pojo.QueryColumn;
|
||||
@@ -60,15 +60,15 @@ import org.springframework.util.CollectionUtils;
|
||||
public class MetricRatioProcessor implements ExecuteResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult) {
|
||||
SemanticParseInfo semanticParseInfo = chatExecuteContext.getParseInfo();
|
||||
public void process(ExecuteContext executeContext, QueryResult queryResult) {
|
||||
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
|
||||
AggregatorConfig aggregatorConfig = ContextUtils.getBean(AggregatorConfig.class);
|
||||
if (CollectionUtils.isEmpty(semanticParseInfo.getMetrics())
|
||||
|| !aggregatorConfig.getEnableRatio()
|
||||
|| !QueryType.METRIC.equals(semanticParseInfo.getQueryType())) {
|
||||
return;
|
||||
}
|
||||
AggregateInfo aggregateInfo = getAggregateInfo(chatExecuteContext.getUser(),
|
||||
AggregateInfo aggregateInfo = getAggregateInfo(executeContext.getUser(),
|
||||
semanticParseInfo, queryResult);
|
||||
queryResult.setAggregateInfo(aggregateInfo);
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.processor.execute;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.DictWordType;
|
||||
import com.tencent.supersonic.common.pojo.enums.QueryType;
|
||||
@@ -34,8 +34,8 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
|
||||
private static final int METRIC_RECOMMEND_SIZE = 5;
|
||||
|
||||
@Override
|
||||
public void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult) {
|
||||
fillSimilarMetric(chatExecuteContext.getParseInfo());
|
||||
public void process(ExecuteContext executeContext, QueryResult queryResult) {
|
||||
fillSimilarMetric(executeContext.getParseInfo());
|
||||
}
|
||||
|
||||
private void fillSimilarMetric(SemanticParseInfo parseInfo) {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
|
||||
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
|
||||
@@ -19,7 +19,7 @@ import java.util.List;
|
||||
public class EntityInfoProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
public void process(ParseContext parseContext, ParseResp parseResp) {
|
||||
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
|
||||
if (CollectionUtils.isEmpty(selectedParses)) {
|
||||
return;
|
||||
@@ -33,7 +33,7 @@ public class EntityInfoProcessor implements ParseResultProcessor {
|
||||
//1. set entity info
|
||||
SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class);
|
||||
DataSetSchema dataSetSchema = semanticService.getDataSetSchema(parseInfo.getDataSetId());
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, chatParseContext.getUser());
|
||||
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, parseContext.getUser());
|
||||
if (QueryManager.isTagQuery(queryMode)
|
||||
|| QueryManager.isMetricQuery(queryMode)) {
|
||||
parseInfo.setEntityInfo(entityInfo);
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
|
||||
public interface ParseResultProcessor {
|
||||
|
||||
void process(ChatParseContext chatParseContext, ParseResp parseResp);
|
||||
void process(ParseContext parseContext, ParseResp parseResp);
|
||||
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
@@ -25,15 +25,15 @@ import java.util.stream.Collectors;
|
||||
public class QueryRecommendProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
CompletableFuture.runAsync(() -> doProcess(parseResp, chatParseContext));
|
||||
public void process(ParseContext parseContext, ParseResp parseResp) {
|
||||
CompletableFuture.runAsync(() -> doProcess(parseResp, parseContext));
|
||||
}
|
||||
|
||||
@SneakyThrows
|
||||
private void doProcess(ParseResp parseResp, ChatParseContext chatParseContext) {
|
||||
private void doProcess(ParseResp parseResp, ParseContext parseContext) {
|
||||
Long queryId = parseResp.getQueryId();
|
||||
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(chatParseContext.getQueryText(),
|
||||
chatParseContext.getAgent().getId());
|
||||
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(parseContext.getQueryText(),
|
||||
parseContext.getAgent().getId());
|
||||
ChatQueryDO chatQueryDO = getChatQuery(queryId);
|
||||
chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries));
|
||||
updateChatQuery(chatQueryDO);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
@@ -12,7 +12,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||
public class TimeCostProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
|
||||
public void process(ParseContext parseContext, ParseResp parseResp) {
|
||||
long parseStartTime = parseResp.getParseTimeCost().getParseStartTime();
|
||||
parseResp.getParseTimeCost().setParseTime(
|
||||
System.currentTimeMillis() - parseStartTime - parseResp.getParseTimeCost().getSqlTime());
|
||||
|
||||
@@ -6,11 +6,10 @@ import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import org.apache.commons.collections.CollectionUtils;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
@@ -32,20 +31,20 @@ import javax.validation.Valid;
|
||||
public class ChatQueryController {
|
||||
|
||||
@Autowired
|
||||
private ChatService chatService;
|
||||
private ChatQueryService chatQueryService;
|
||||
|
||||
@PostMapping("search")
|
||||
public Object search(@RequestBody ChatParseReq chatParseReq, HttpServletRequest request,
|
||||
HttpServletResponse response) {
|
||||
chatParseReq.setUser(UserHolder.findUser(request, response));
|
||||
return chatService.search(chatParseReq);
|
||||
return chatQueryService.search(chatParseReq);
|
||||
}
|
||||
|
||||
@PostMapping("parse")
|
||||
public Object parse(@RequestBody ChatParseReq chatParseReq,
|
||||
HttpServletRequest request, HttpServletResponse response) throws Exception {
|
||||
chatParseReq.setUser(UserHolder.findUser(request, response));
|
||||
return chatService.performParsing(chatParseReq);
|
||||
return chatQueryService.performParsing(chatParseReq);
|
||||
}
|
||||
|
||||
@PostMapping("execute")
|
||||
@@ -53,7 +52,7 @@ public class ChatQueryController {
|
||||
HttpServletRequest request, HttpServletResponse response)
|
||||
throws Exception {
|
||||
chatExecuteReq.setUser(UserHolder.findUser(request, response));
|
||||
return chatService.performExecution(chatExecuteReq);
|
||||
return chatQueryService.performExecution(chatExecuteReq);
|
||||
}
|
||||
|
||||
@PostMapping("/")
|
||||
@@ -62,7 +61,7 @@ public class ChatQueryController {
|
||||
throws Exception {
|
||||
User user = UserHolder.findUser(request, response);
|
||||
chatParseReq.setUser(user);
|
||||
ParseResp parseResp = chatService.performParsing(chatParseReq);
|
||||
ParseResp parseResp = chatQueryService.performParsing(chatParseReq);
|
||||
|
||||
if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) {
|
||||
throw new InvalidArgumentException("parser error,no selectedParses");
|
||||
@@ -72,27 +71,20 @@ public class ChatQueryController {
|
||||
BeanUtils.copyProperties(chatParseReq, chatExecuteReq);
|
||||
chatExecuteReq.setQueryId(parseResp.getQueryId());
|
||||
chatExecuteReq.setParseId(semanticParseInfo.getId());
|
||||
return chatService.performExecution(chatExecuteReq);
|
||||
}
|
||||
|
||||
@PostMapping("queryContext")
|
||||
public Object queryContext(@RequestBody QueryNLReq queryCtx,
|
||||
HttpServletRequest request, HttpServletResponse response) {
|
||||
queryCtx.setUser(UserHolder.findUser(request, response));
|
||||
return chatService.queryContext(queryCtx.getChatId());
|
||||
return chatQueryService.performExecution(chatExecuteReq);
|
||||
}
|
||||
|
||||
@PostMapping("queryData")
|
||||
public Object queryData(@RequestBody ChatQueryDataReq chatQueryDataReq,
|
||||
HttpServletRequest request, HttpServletResponse response) throws Exception {
|
||||
chatQueryDataReq.setUser(UserHolder.findUser(request, response));
|
||||
return chatService.queryData(chatQueryDataReq, UserHolder.findUser(request, response));
|
||||
return chatQueryService.queryData(chatQueryDataReq, UserHolder.findUser(request, response));
|
||||
}
|
||||
|
||||
@PostMapping("queryDimensionValue")
|
||||
public Object queryDimensionValue(@RequestBody @Valid DimensionValueReq dimensionValueReq,
|
||||
HttpServletRequest request, HttpServletResponse response) throws Exception {
|
||||
return chatService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response));
|
||||
return chatQueryService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
|
||||
public interface ChatContextService {
|
||||
|
||||
ChatContext getOrCreateContext(Integer chatId);
|
||||
|
||||
void updateContext(ChatContext chatCtx);
|
||||
|
||||
}
|
||||
@@ -4,7 +4,6 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
@@ -12,7 +11,7 @@ import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface ChatService {
|
||||
public interface ChatQueryService {
|
||||
|
||||
List<SearchResult> search(ChatParseReq chatParseReq);
|
||||
|
||||
@@ -24,8 +23,6 @@ public interface ChatService {
|
||||
|
||||
Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception;
|
||||
|
||||
SemanticParseInfo queryContext(Integer chatId);
|
||||
|
||||
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
|
||||
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.service;
|
||||
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
||||
import com.tencent.supersonic.common.config.ChatModelConfig;
|
||||
@@ -36,7 +36,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
private MemoryService memoryService;
|
||||
|
||||
@Autowired
|
||||
private ChatService chatService;
|
||||
private ChatQueryService chatQueryService;
|
||||
|
||||
private ExecutorService executorService = Executors.newFixedThreadPool(1);
|
||||
|
||||
@@ -103,7 +103,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
|
||||
continue;
|
||||
}
|
||||
try {
|
||||
chatService.parseAndExecute(-1, agent.getId(), example);
|
||||
chatQueryService.parseAndExecute(-1, agent.getId(), example);
|
||||
} catch (Exception e) {
|
||||
log.warn("agent:{} example execute failed:{}", agent.getName(), example);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.server.persistence.repository.ChatContextRepository;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class ChatContextServiceImpl implements ChatContextService {
|
||||
|
||||
private ChatContextRepository chatContextRepository;
|
||||
|
||||
public ChatContextServiceImpl(ChatContextRepository chatContextRepository) {
|
||||
this.chatContextRepository = chatContextRepository;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ChatContext getOrCreateContext(Integer chatId) {
|
||||
return chatContextRepository.getOrCreateContext(chatId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateContext(ChatContext chatCtx) {
|
||||
log.debug("save ChatContext {}", chatCtx);
|
||||
chatContextRepository.updateContext(chatCtx);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -6,15 +6,16 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.executor.ChatExecutor;
|
||||
import com.tencent.supersonic.chat.server.parser.ChatParser;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor;
|
||||
import com.tencent.supersonic.chat.server.parser.ChatQueryParser;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.chat.server.util.ComponentFactory;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
@@ -26,7 +27,7 @@ import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.RetrieveService;
|
||||
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -39,49 +40,51 @@ import java.util.List;
|
||||
|
||||
@Slf4j
|
||||
@Service
|
||||
public class ChatServiceImpl implements ChatService {
|
||||
public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
|
||||
@Autowired
|
||||
private ChatManageService chatManageService;
|
||||
@Autowired
|
||||
private ChatQueryService chatQueryService;
|
||||
private ChatLayerService chatLayerService;
|
||||
@Autowired
|
||||
private RetrieveService retrieveService;
|
||||
@Autowired
|
||||
private AgentService agentService;
|
||||
@Autowired
|
||||
private SemanticLayerService semanticLayerService;
|
||||
@Autowired
|
||||
private ChatContextService chatContextService;
|
||||
|
||||
private List<ChatParser> chatParsers = ComponentFactory.getChatParsers();
|
||||
private List<ChatExecutor> chatExecutors = ComponentFactory.getChatExecutors();
|
||||
private List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
|
||||
private List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors();
|
||||
private List<ParseResultProcessor> parseResultProcessors = ComponentFactory.getParseProcessors();
|
||||
private List<ExecuteResultProcessor> executeResultProcessors = ComponentFactory.getExecuteProcessors();
|
||||
|
||||
@Override
|
||||
public List<SearchResult> search(ChatParseReq chatParseReq) {
|
||||
ChatParseContext chatParseContext = buildParseContext(chatParseReq);
|
||||
Agent agent = chatParseContext.getAgent();
|
||||
ParseContext parseContext = buildParseContext(chatParseReq);
|
||||
Agent agent = parseContext.getAgent();
|
||||
if (!agent.enableSearch()) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
||||
return retrieveService.retrieve(queryNLReq);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ParseResp performParsing(ChatParseReq chatParseReq) {
|
||||
ParseResp parseResp = new ParseResp(chatParseReq.getChatId(), chatParseReq.getQueryText());
|
||||
ParseResp parseResp = new ParseResp(chatParseReq.getQueryText());
|
||||
chatManageService.createChatQuery(chatParseReq, parseResp);
|
||||
ChatParseContext chatParseContext = buildParseContext(chatParseReq);
|
||||
supplyMapInfo(chatParseContext);
|
||||
for (ChatParser chatParser : chatParsers) {
|
||||
chatParser.parse(chatParseContext, parseResp);
|
||||
ParseContext parseContext = buildParseContext(chatParseReq);
|
||||
supplyMapInfo(parseContext);
|
||||
for (ChatQueryParser chatQueryParser : chatQueryParsers) {
|
||||
chatQueryParser.parse(parseContext, parseResp);
|
||||
}
|
||||
for (ParseResultProcessor processor : parseResultProcessors) {
|
||||
processor.process(chatParseContext, parseResp);
|
||||
processor.process(parseContext, parseResp);
|
||||
}
|
||||
chatParseReq.setQueryText(chatParseContext.getQueryText());
|
||||
parseResp.setQueryText(chatParseContext.getQueryText());
|
||||
chatParseReq.setQueryText(parseContext.getQueryText());
|
||||
parseResp.setQueryText(parseContext.getQueryText());
|
||||
chatManageService.batchAddParse(chatParseReq, parseResp);
|
||||
chatManageService.updateParseCostTime(parseResp);
|
||||
return parseResp;
|
||||
@@ -90,9 +93,9 @@ public class ChatServiceImpl implements ChatService {
|
||||
@Override
|
||||
public QueryResult performExecution(ChatExecuteReq chatExecuteReq) {
|
||||
QueryResult queryResult = new QueryResult();
|
||||
ChatExecuteContext chatExecuteContext = buildExecuteContext(chatExecuteReq);
|
||||
for (ChatExecutor chatExecutor : chatExecutors) {
|
||||
queryResult = chatExecutor.execute(chatExecuteContext);
|
||||
ExecuteContext executeContext = buildExecuteContext(chatExecuteReq);
|
||||
for (ChatQueryExecutor chatQueryExecutor : chatQueryExecutors) {
|
||||
queryResult = chatQueryExecutor.execute(executeContext);
|
||||
if (queryResult != null) {
|
||||
break;
|
||||
}
|
||||
@@ -100,7 +103,7 @@ public class ChatServiceImpl implements ChatService {
|
||||
|
||||
if (queryResult != null) {
|
||||
for (ExecuteResultProcessor processor : executeResultProcessors) {
|
||||
processor.process(chatExecuteContext, queryResult);
|
||||
processor.process(executeContext, queryResult);
|
||||
}
|
||||
saveQueryResult(chatExecuteReq, queryResult);
|
||||
}
|
||||
@@ -125,34 +128,36 @@ public class ChatServiceImpl implements ChatService {
|
||||
executeReq.setQueryId(parseResp.getQueryId());
|
||||
executeReq.setParseId(parseResp.getSelectedParses().get(0).getId());
|
||||
executeReq.setQueryText(queryText);
|
||||
executeReq.setChatId(parseResp.getChatId());
|
||||
executeReq.setChatId(chatId);
|
||||
executeReq.setUser(User.getFakeUser());
|
||||
executeReq.setAgentId(agentId);
|
||||
executeReq.setSaveAnswer(true);
|
||||
return performExecution(executeReq);
|
||||
}
|
||||
|
||||
private ChatParseContext buildParseContext(ChatParseReq chatParseReq) {
|
||||
ChatParseContext chatParseContext = new ChatParseContext();
|
||||
BeanMapper.mapper(chatParseReq, chatParseContext);
|
||||
private ParseContext buildParseContext(ChatParseReq chatParseReq) {
|
||||
ParseContext parseContext = new ParseContext();
|
||||
BeanMapper.mapper(chatParseReq, parseContext);
|
||||
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
|
||||
chatParseContext.setAgent(agent);
|
||||
return chatParseContext;
|
||||
parseContext.setAgent(agent);
|
||||
return parseContext;
|
||||
}
|
||||
|
||||
private void supplyMapInfo(ChatParseContext chatParseContext) {
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
|
||||
MapResp mapResp = chatQueryService.performMapping(queryNLReq);
|
||||
chatParseContext.setMapInfo(mapResp.getMapInfo());
|
||||
private void supplyMapInfo(ParseContext parseContext) {
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
||||
MapResp mapResp = chatLayerService.performMapping(queryNLReq);
|
||||
parseContext.setMapInfo(mapResp.getMapInfo());
|
||||
}
|
||||
|
||||
private ChatExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
|
||||
ChatExecuteContext chatExecuteContext = new ChatExecuteContext();
|
||||
BeanMapper.mapper(chatExecuteReq, chatExecuteContext);
|
||||
private ExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
|
||||
ExecuteContext executeContext = new ExecuteContext();
|
||||
BeanMapper.mapper(chatExecuteReq, executeContext);
|
||||
SemanticParseInfo parseInfo = chatManageService.getParseInfo(
|
||||
chatExecuteReq.getQueryId(), chatExecuteReq.getParseId());
|
||||
chatExecuteContext.setParseInfo(parseInfo);
|
||||
return chatExecuteContext;
|
||||
Agent agent = agentService.getAgent(chatExecuteReq.getAgentId());
|
||||
executeContext.setAgent(agent);
|
||||
executeContext.setParseInfo(parseInfo);
|
||||
return executeContext;
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -163,12 +168,7 @@ public class ChatServiceImpl implements ChatService {
|
||||
QueryDataReq queryData = new QueryDataReq();
|
||||
BeanMapper.mapper(chatQueryDataReq, queryData);
|
||||
queryData.setParseInfo(parseInfo);
|
||||
return chatQueryService.executeDirectQuery(queryData, user);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SemanticParseInfo queryContext(Integer chatId) {
|
||||
return chatQueryService.queryContext(chatId);
|
||||
return chatLayerService.executeDirectQuery(queryData, user);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -176,7 +176,7 @@ public class ChatServiceImpl implements ChatService {
|
||||
Integer agentId = dimensionValueReq.getAgentId();
|
||||
Agent agent = agentService.getAgent(agentId);
|
||||
dimensionValueReq.setDataSetIds(agent.getDataSetIds());
|
||||
return chatQueryService.queryDimensionValue(dimensionValueReq, user);
|
||||
return chatLayerService.queryDimensionValue(dimensionValueReq, user);
|
||||
}
|
||||
|
||||
public void saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult) {
|
||||
@@ -1,8 +1,8 @@
|
||||
package com.tencent.supersonic.chat.server.service.impl;
|
||||
|
||||
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
|
||||
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
|
||||
import com.tencent.supersonic.chat.server.service.StatisticsService;
|
||||
import com.tencent.supersonic.headless.server.persistence.mapper.StatisticsMapper;
|
||||
import com.tencent.supersonic.chat.server.persistence.mapper.StatisticsMapper;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.util;
|
||||
|
||||
import com.tencent.supersonic.chat.server.executor.ChatExecutor;
|
||||
import com.tencent.supersonic.chat.server.parser.ChatParser;
|
||||
import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor;
|
||||
import com.tencent.supersonic.chat.server.parser.ChatQueryParser;
|
||||
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
|
||||
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
|
||||
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
|
||||
@@ -16,8 +16,8 @@ import java.util.List;
|
||||
public class ComponentFactory {
|
||||
private static List<ParseResultProcessor> parseProcessors = new ArrayList<>();
|
||||
private static List<ExecuteResultProcessor> executeProcessors = new ArrayList<>();
|
||||
private static List<ChatParser> chatParsers = new ArrayList<>();
|
||||
private static List<ChatExecutor> chatExecutors = new ArrayList<>();
|
||||
private static List<ChatQueryParser> chatQueryParsers = new ArrayList<>();
|
||||
private static List<ChatQueryExecutor> chatQueryExecutors = new ArrayList<>();
|
||||
private static List<PluginRecognizer> pluginRecognizers = new ArrayList<>();
|
||||
|
||||
public static List<ParseResultProcessor> getParseProcessors() {
|
||||
@@ -30,14 +30,14 @@ public class ComponentFactory {
|
||||
? init(ExecuteResultProcessor.class, executeProcessors) : executeProcessors;
|
||||
}
|
||||
|
||||
public static List<ChatParser> getChatParsers() {
|
||||
return CollectionUtils.isEmpty(chatParsers)
|
||||
? init(ChatParser.class, chatParsers) : chatParsers;
|
||||
public static List<ChatQueryParser> getChatParsers() {
|
||||
return CollectionUtils.isEmpty(chatQueryParsers)
|
||||
? init(ChatQueryParser.class, chatQueryParsers) : chatQueryParsers;
|
||||
}
|
||||
|
||||
public static List<ChatExecutor> getChatExecutors() {
|
||||
return CollectionUtils.isEmpty(chatExecutors)
|
||||
? init(ChatExecutor.class, chatExecutors) : chatExecutors;
|
||||
public static List<ChatQueryExecutor> getChatExecutors() {
|
||||
return CollectionUtils.isEmpty(chatQueryExecutors)
|
||||
? init(ChatQueryExecutor.class, chatQueryExecutors) : chatQueryExecutors;
|
||||
}
|
||||
|
||||
public static List<PluginRecognizer> getPluginRecognizers() {
|
||||
|
||||
@@ -1,20 +1,25 @@
|
||||
package com.tencent.supersonic.chat.server.util;
|
||||
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import org.apache.commons.collections.MapUtils;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
public class QueryReqConverter {
|
||||
|
||||
public static QueryNLReq buildText2SqlQueryReq(ChatParseContext chatParseContext) {
|
||||
public static QueryNLReq buildText2SqlQueryReq(ParseContext parseContext) {
|
||||
return buildText2SqlQueryReq(parseContext, null);
|
||||
}
|
||||
|
||||
public static QueryNLReq buildText2SqlQueryReq(ParseContext parseContext, ChatContext chatCtx) {
|
||||
QueryNLReq queryNLReq = new QueryNLReq();
|
||||
BeanMapper.mapper(chatParseContext, queryNLReq);
|
||||
Agent agent = chatParseContext.getAgent();
|
||||
BeanMapper.mapper(parseContext, queryNLReq);
|
||||
Agent agent = parseContext.getAgent();
|
||||
if (agent == null) {
|
||||
return queryNLReq;
|
||||
}
|
||||
@@ -39,6 +44,9 @@ public class QueryReqConverter {
|
||||
}
|
||||
queryNLReq.setModelConfig(agent.getModelConfig());
|
||||
queryNLReq.setPromptConfig(agent.getPromptConfig());
|
||||
if (chatCtx != null) {
|
||||
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
|
||||
}
|
||||
return queryNLReq;
|
||||
}
|
||||
|
||||
|
||||
30
chat/server/src/main/resources/mapper/ChatContextMapper.xml
Normal file
30
chat/server/src/main/resources/mapper/ChatContextMapper.xml
Normal file
@@ -0,0 +1,30 @@
|
||||
<?xml version="1.0" encoding="UTF-8" ?>
|
||||
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
|
||||
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
|
||||
|
||||
|
||||
<mapper namespace="com.tencent.supersonic.chat.server.persistence.mapper.ChatContextMapper">
|
||||
|
||||
<resultMap id="ChatContextDO"
|
||||
type="com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO">
|
||||
<id column="chat_id" property="chatId"/>
|
||||
<result column="modified_at" property="modifiedAt"/>
|
||||
<result column="user" property="user"/>
|
||||
<result column="query_text" property="queryText"/>
|
||||
<result column="semantic_parse" property="semanticParse"/>
|
||||
<!--<result column="ext_data" property="extData"/>-->
|
||||
</resultMap>
|
||||
|
||||
<select id="getContextByChatId" resultMap="ChatContextDO">
|
||||
select *
|
||||
from s2_chat_context where chat_id=#{chatId} limit 1
|
||||
</select>
|
||||
|
||||
<insert id="addContext" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO" >
|
||||
insert into s2_chat_context (chat_id,user,query_text,semantic_parse) values (#{chatId}, #{user},#{queryText}, #{semanticParse})
|
||||
</insert>
|
||||
<update id="updateContext">
|
||||
update s2_chat_context set query_text=#{queryText},semantic_parse=#{semanticParse} where chat_id=#{chatId}
|
||||
</update>
|
||||
|
||||
</mapper>
|
||||
28
chat/server/src/main/resources/mapper/StatisticsMapper.xml
Normal file
28
chat/server/src/main/resources/mapper/StatisticsMapper.xml
Normal file
@@ -0,0 +1,28 @@
|
||||
<?xml version="1.0" encoding="UTF-8" ?>
|
||||
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
|
||||
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
|
||||
|
||||
|
||||
<mapper namespace="com.tencent.supersonic.chat.server.persistence.mapper.StatisticsMapper">
|
||||
|
||||
<resultMap id="Statistics" type="com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO">
|
||||
<id column="question_id" property="questionId"/>
|
||||
<result column="chat_id" property="chatId"/>
|
||||
<result column="user_name" property="userName"/>
|
||||
<result column="query_text" property="queryText"/>
|
||||
<result column="interface_name" property="interfaceName"/>
|
||||
<result column="cost" property="cost"/>
|
||||
<result column="type" property="type"/>
|
||||
<result column="create_time" property="createTime"/>
|
||||
</resultMap>
|
||||
|
||||
<insert id="batchSaveStatistics" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO">
|
||||
insert into s2_chat_statistics
|
||||
(question_id,chat_id, user_name, query_text, interface_name,cost,type ,create_time)
|
||||
values
|
||||
<foreach collection="list" item="item" index="index" separator=",">
|
||||
(#{item.questionId}, #{item.chatId}, #{item.userName}, #{item.queryText}, #{item.interfaceName}, #{item.cost}, #{item.type},#{item.createTime})
|
||||
</foreach>
|
||||
</insert>
|
||||
|
||||
</mapper>
|
||||
Reference in New Issue
Block a user