(improvement)(headless&chat)Move ChatContext from Headless module to Chat module.

This commit is contained in:
jerryjzhang
2024-07-12 16:56:25 +08:00
parent e365a36749
commit baff30550e
78 changed files with 399 additions and 459 deletions

View File

@@ -1,10 +0,0 @@
package com.tencent.supersonic.chat.server.executor;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
public interface ChatExecutor {
QueryResult execute(ChatExecuteContext chatExecuteContext);
}

View File

@@ -0,0 +1,10 @@
package com.tencent.supersonic.chat.server.executor;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
public interface ChatQueryExecutor {
QueryResult execute(ExecuteContext executeContext);
}

View File

@@ -4,7 +4,7 @@ import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.parser.ParserConfig;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
@@ -23,7 +23,7 @@ import java.util.stream.Collectors;
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
public class PlainTextExecutor implements ChatExecutor {
public class PlainTextExecutor implements ChatQueryExecutor {
private static final String INSTRUCTION = ""
+ "#Role: You are a nice person to talk to.\n"
@@ -34,34 +34,34 @@ public class PlainTextExecutor implements ChatExecutor {
+ "#Your response: ";
@Override
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
if (!"PLAIN_TEXT".equals(chatExecuteContext.getParseInfo().getQueryMode())) {
public QueryResult execute(ExecuteContext executeContext) {
if (!"PLAIN_TEXT".equals(executeContext.getParseInfo().getQueryMode())) {
return null;
}
String promptStr = String.format(INSTRUCTION, getHistoryInputs(chatExecuteContext),
chatExecuteContext.getQueryText());
String promptStr = String.format(INSTRUCTION, getHistoryInputs(executeContext),
executeContext.getQueryText());
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatAgent.getModelConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
QueryResult result = new QueryResult();
result.setQueryState(QueryState.SUCCESS);
result.setQueryMode(chatExecuteContext.getParseInfo().getQueryMode());
result.setQueryMode(executeContext.getParseInfo().getQueryMode());
result.setTextResult(response.content().text());
return result;
}
private String getHistoryInputs(ChatExecuteContext chatExecuteContext) {
private String getHistoryInputs(ExecuteContext executeContext) {
StringBuilder historyInput = new StringBuilder();
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
MultiTurnConfig agentMultiTurnConfig = chatAgent.getMultiTurnConfig();
@@ -70,7 +70,7 @@ public class PlainTextExecutor implements ChatExecutor {
? agentMultiTurnConfig.isEnableMultiTurn() : globalMultiTurnConfig;
if (Boolean.TRUE.equals(multiTurnConfig)) {
List<ParseResp> parseResps = getHistoryParseResult(chatExecuteContext.getChatId(), 5);
List<ParseResp> parseResps = getHistoryParseResult(executeContext.getChatId(), 5);
parseResps.stream().forEach(p -> {
historyInput.append(p.getQueryText());
historyInput.append(";");

View File

@@ -2,15 +2,15 @@ package com.tencent.supersonic.chat.server.executor;
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
public class PluginExecutor implements ChatExecutor {
public class PluginExecutor implements ChatQueryExecutor {
@Override
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo();
public QueryResult execute(ExecuteContext executeContext) {
SemanticParseInfo parseInfo = executeContext.getParseInfo();
if (!PluginQueryManager.isPluginQuery(parseInfo.getQueryMode())) {
return null;
}

View File

@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.server.executor;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.chat.server.util.ResultFormatter;
import com.tencent.supersonic.common.pojo.QueryColumn;
@@ -12,10 +12,10 @@ import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import com.tencent.supersonic.headless.server.web.service.ChatContextService;
import com.tencent.supersonic.chat.server.service.ChatContextService;
import lombok.SneakyThrows;
import org.apache.commons.lang3.StringUtils;
@@ -25,12 +25,12 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
public class SqlExecutor implements ChatExecutor {
public class SqlExecutor implements ChatQueryExecutor {
@SneakyThrows
@Override
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
QueryResult queryResult = doExecute(chatExecuteContext);
public QueryResult execute(ExecuteContext executeContext) {
QueryResult queryResult = doExecute(executeContext);
if (queryResult != null) {
String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(),
@@ -41,13 +41,13 @@ public class SqlExecutor implements ChatExecutor {
&& queryResult.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
memoryService.createMemory(ChatMemoryDO.builder()
.agentId(chatExecuteContext.getAgentId())
.agentId(executeContext.getAgent().getId())
.status(MemoryStatus.PENDING)
.question(chatExecuteContext.getQueryText())
.s2sql(chatExecuteContext.getParseInfo().getSqlInfo().getParsedS2SQL())
.dbSchema(buildSchemaStr(chatExecuteContext.getParseInfo()))
.createdBy(chatExecuteContext.getUser().getName())
.updatedBy(chatExecuteContext.getUser().getName())
.question(executeContext.getQueryText())
.s2sql(executeContext.getParseInfo().getSqlInfo().getParsedS2SQL())
.dbSchema(buildSchemaStr(executeContext.getParseInfo()))
.createdBy(executeContext.getUser().getName())
.updatedBy(executeContext.getUser().getName())
.createdAt(new Date())
.build());
}
@@ -57,12 +57,12 @@ public class SqlExecutor implements ChatExecutor {
}
@SneakyThrows
private QueryResult doExecute(ChatExecuteContext chatExecuteContext) {
private QueryResult doExecute(ExecuteContext executeContext) {
SemanticLayerService semanticLayer = ContextUtils.getBean(SemanticLayerService.class);
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
ChatContext chatCtx = chatContextService.getOrCreateContext(chatExecuteContext.getChatId());
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo();
ChatContext chatCtx = chatContextService.getOrCreateContext(executeContext.getChatId());
SemanticParseInfo parseInfo = executeContext.getParseInfo();
if (Objects.isNull(parseInfo.getSqlInfo())
|| StringUtils.isBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) {
return null;
@@ -74,8 +74,10 @@ public class SqlExecutor implements ChatExecutor {
sqlReq.setSqlInfo(parseInfo.getSqlInfo());
sqlReq.setDataSetId(parseInfo.getDataSetId());
long startTime = System.currentTimeMillis();
SemanticQueryResp queryResp = semanticLayer.queryByReq(sqlReq, chatExecuteContext.getUser());
SemanticQueryResp queryResp = semanticLayer.queryByReq(sqlReq, executeContext.getUser());
QueryResult queryResult = new QueryResult();
queryResult.setChatContext(parseInfo);
queryResult.setQueryMode(parseInfo.getQueryMode());
if (queryResp != null) {
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
List<Map<String, Object>> resultList = queryResp == null ? new ArrayList<>()
@@ -85,7 +87,6 @@ public class SqlExecutor implements ChatExecutor {
queryResult.setQuerySql(queryResp.getSql());
queryResult.setQueryResults(resultList);
queryResult.setQueryColumns(columns);
queryResult.setQueryMode(parseInfo.getQueryMode());
queryResult.setQueryState(QueryState.SUCCESS);
chatCtx.setParseInfo(parseInfo);
@@ -94,7 +95,6 @@ public class SqlExecutor implements ChatExecutor {
queryResult.setQueryState(QueryState.INVALID);
queryResult.setQueryMode(parseInfo.getQueryMode());
}
queryResult.setChatContext(chatCtx.getParseInfo());
return queryResult;
}

View File

@@ -1,10 +0,0 @@
package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
public interface ChatParser {
void parse(ChatParseContext chatParseContext, ParseResp parseResp);
}

View File

@@ -0,0 +1,10 @@
package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
public interface ChatQueryParser {
void parse(ParseContext parseContext, ParseResp parseResp);
}

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.chat.server.util.ComponentFactory;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
@@ -9,18 +9,18 @@ import lombok.extern.slf4j.Slf4j;
import java.util.List;
@Slf4j
public class NL2PluginParser implements ChatParser {
public class NL2PluginParser implements ChatQueryParser {
private final List<PluginRecognizer> pluginRecognizers = ComponentFactory.getPluginRecognizers();
@Override
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
if (!chatParseContext.getAgent().containsPluginTool()) {
public void parse(ParseContext parseContext, ParseResp parseResp) {
if (!parseContext.getAgent().containsPluginTool()) {
return;
}
pluginRecognizers.forEach(pluginRecognizer -> {
pluginRecognizer.recognize(chatParseContext, parseResp);
pluginRecognizer.recognize(parseContext, parseResp);
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),
JsonUtil.toString(parseResp));
});

View File

@@ -3,7 +3,8 @@ package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.chat.server.service.ChatContextService;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar;
@@ -17,7 +18,8 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
@@ -42,7 +44,7 @@ import java.util.stream.Collectors;
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
@Slf4j
public class NL2SQLParser implements ChatParser {
public class NL2SQLParser implements ChatQueryParser {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@@ -73,27 +75,30 @@ public class NL2SQLParser implements ChatParser {
+ "#Response: ";
@Override
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) {
public void parse(ParseContext parseContext, ParseResp parseResp) {
if (!parseContext.enableNL2SQL() || checkSkip(parseResp)) {
return;
}
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
ChatContext chatCtx = chatContextService.getOrCreateContext(parseContext.getChatId());
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
chatParseContext.getAgent().getModelConfig());
parseContext.getAgent().getModelConfig());
processMultiTurn(chatLanguageModel, chatParseContext);
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
addDynamicExemplars(chatParseContext.getAgent().getId(), queryNLReq);
processMultiTurn(chatLanguageModel, parseContext);
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx);
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryNLReq);
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
ParseResp text2SqlParseResp = chatLayerService.performParsing(queryNLReq);
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
} else {
parseResp.setErrorMsg(rewriteErrorMessage(chatLanguageModel,
chatParseContext.getQueryText(),
parseContext.getQueryText(),
text2SqlParseResp.getErrorMsg(),
queryNLReq.getDynamicExemplars(),
chatParseContext.getAgent().getExamples()));
parseContext.getAgent().getExamples()));
}
parseResp.setState(text2SqlParseResp.getState());
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
@@ -155,9 +160,9 @@ public class NL2SQLParser implements ChatParser {
parseInfo.setTextInfo(textBuilder.toString());
}
private void processMultiTurn(ChatLanguageModel chatLanguageModel, ChatParseContext chatParseContext) {
private void processMultiTurn(ChatLanguageModel chatLanguageModel, ParseContext parseContext) {
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig();
MultiTurnConfig agentMultiTurnConfig = parseContext.getAgent().getMultiTurnConfig();
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
Boolean multiTurnConfig = agentMultiTurnConfig != null
@@ -167,11 +172,11 @@ public class NL2SQLParser implements ChatParser {
}
// derive mapping result of current question and parsing result of last question.
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
MapResp currentMapResult = chatQueryService.performMapping(queryNLReq);
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
MapResp currentMapResult = chatLayerService.performMapping(queryNLReq);
List<ParseResp> historyParseResults = getHistoryParseResult(chatParseContext.getChatId(), 1);
List<ParseResp> historyParseResults = getHistoryParseResult(parseContext.getChatId(), 1);
if (historyParseResults.size() == 0) {
return;
}
@@ -196,7 +201,7 @@ public class NL2SQLParser implements ChatParser {
String rewrittenQuery = response.content().text();
keyPipelineLog.info("NL2SQLParser modelResp:{}", rewrittenQuery);
chatParseContext.setQueryText(rewrittenQuery);
parseContext.setQueryText(rewrittenQuery);
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
}

View File

@@ -5,7 +5,7 @@ import com.tencent.supersonic.common.pojo.Parameter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
@Service("ChatParserConfig")
@Service("ChatQueryParserConfig")
@Slf4j
public class ParserConfig extends ParameterConfig {

View File

@@ -1,15 +1,15 @@
package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
public class PlainTextParser implements ChatParser {
public class PlainTextParser implements ChatQueryParser {
@Override
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
if (chatParseContext.getAgent().containsAnyTool()) {
public void parse(ParseContext parseContext, ParseResp parseResp) {
if (parseContext.getAgent().containsAnyTool()) {
return;
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.server.persistence.dataobject;
package com.tencent.supersonic.chat.server.persistence.dataobject;
import lombok.Data;

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.server.persistence.dataobject;
package com.tencent.supersonic.chat.server.persistence.dataobject;
import lombok.AllArgsConstructor;
import lombok.Data;

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.headless.server.persistence.mapper;
package com.tencent.supersonic.chat.server.persistence.mapper;
import com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO;
import org.apache.ibatis.annotations.Mapper;
@Mapper

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.headless.server.persistence.mapper;
package com.tencent.supersonic.chat.server.persistence.mapper;
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.headless.server.persistence.repository;
package com.tencent.supersonic.chat.server.persistence.repository;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
public interface ChatContextRepository {

View File

@@ -1,12 +1,12 @@
package com.tencent.supersonic.headless.server.persistence.repository.impl;
package com.tencent.supersonic.chat.server.persistence.repository.impl;
import com.google.gson.Gson;
import com.tencent.supersonic.chat.server.persistence.repository.ChatContextRepository;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO;
import com.tencent.supersonic.headless.server.persistence.mapper.ChatContextMapper;
import com.tencent.supersonic.headless.server.persistence.repository.ChatContextRepository;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO;
import com.tencent.supersonic.chat.server.persistence.mapper.ChatContextMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Repository;

View File

@@ -183,7 +183,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
public List<ParseResp> getContextualParseInfo(Integer chatId) {
List<ChatParseDO> chatParseDOList = chatParseMapper.getContextualParseInfo(chatId);
List<ParseResp> semanticParseInfoList = chatParseDOList.stream().map(parseInfo -> {
ParseResp parseResp = new ParseResp(chatId, parseInfo.getQueryText());
ParseResp parseResp = new ParseResp(parseInfo.getQueryText());
List<SemanticParseInfo> selectedParses = new ArrayList<>();
selectedParses.add(JSONObject.parseObject(parseInfo.getParseInfo(), SemanticParseInfo.class));
parseResp.setSelectedParses(selectedParses);

View File

@@ -11,7 +11,7 @@ import com.tencent.supersonic.chat.server.plugin.build.WebBase;
import com.tencent.supersonic.chat.server.plugin.event.PluginAddEvent;
import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent;
import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.chat.server.service.PluginService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.service.EmbeddingService;
@@ -52,9 +52,9 @@ public class PluginManager {
@Autowired
private EmbeddingService embeddingService;
public static List<ChatPlugin> getPluginAgentCanSupport(ChatParseContext chatParseContext) {
public static List<ChatPlugin> getPluginAgentCanSupport(ParseContext parseContext) {
PluginService pluginService = ContextUtils.getBean(PluginService.class);
Agent agent = chatParseContext.getAgent();
Agent agent = parseContext.getAgent();
List<ChatPlugin> plugins = pluginService.getPluginList();
if (Objects.isNull(agent)) {
return plugins;
@@ -191,9 +191,9 @@ public class PluginManager {
return String.valueOf(Integer.parseInt(id) / 1000);
}
public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ChatParseContext chatParseContext) {
SchemaMapInfo schemaMapInfo = chatParseContext.getMapInfo();
Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, chatParseContext);
public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ParseContext parseContext) {
SchemaMapInfo schemaMapInfo = parseContext.getMapInfo();
Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, parseContext);
if (CollectionUtils.isEmpty(pluginMatchedDataSet) && !plugin.isContainsAllDataSet()) {
return Pair.of(false, Sets.newHashSet());
}
@@ -259,8 +259,8 @@ public class PluginManager {
.collect(Collectors.toList());
}
private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, ChatParseContext chatParseContext) {
Set<Long> matchedDataSets = chatParseContext.getMapInfo().getMatchedDataSetInfos();
private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, ParseContext parseContext) {
Set<Long> matchedDataSets = parseContext.getMapInfo().getMatchedDataSetInfos();
if (plugin.isContainsAllDataSet()) {
return Sets.newHashSet(plugin.getDefaultMode());
}

View File

@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
import com.tencent.supersonic.chat.server.plugin.PluginManager;
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
@@ -28,22 +28,22 @@ import java.util.Set;
*/
public abstract class PluginRecognizer {
public void recognize(ChatParseContext chatParseContext, ParseResp parseResp) {
if (!checkPreCondition(chatParseContext)) {
public void recognize(ParseContext parseContext, ParseResp parseResp) {
if (!checkPreCondition(parseContext)) {
return;
}
PluginRecallResult pluginRecallResult = recallPlugin(chatParseContext);
PluginRecallResult pluginRecallResult = recallPlugin(parseContext);
if (pluginRecallResult == null) {
return;
}
buildQuery(chatParseContext, parseResp, pluginRecallResult);
buildQuery(parseContext, parseResp, pluginRecallResult);
}
public abstract boolean checkPreCondition(ChatParseContext chatParseContext);
public abstract boolean checkPreCondition(ParseContext parseContext);
public abstract PluginRecallResult recallPlugin(ChatParseContext chatParseContext);
public abstract PluginRecallResult recallPlugin(ParseContext parseContext);
public void buildQuery(ChatParseContext chatParseContext, ParseResp parseResp,
public void buildQuery(ParseContext parseContext, ParseResp parseResp,
PluginRecallResult pluginRecallResult) {
ChatPlugin plugin = pluginRecallResult.getPlugin();
Set<Long> dataSetIds = pluginRecallResult.getDataSetIds();
@@ -52,21 +52,21 @@ public abstract class PluginRecognizer {
}
for (Long dataSetId : dataSetIds) {
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin,
chatParseContext, pluginRecallResult.getDistance());
parseContext, pluginRecallResult.getDistance());
semanticParseInfo.setQueryMode(plugin.getType());
semanticParseInfo.setScore(pluginRecallResult.getScore());
parseResp.getSelectedParses().add(semanticParseInfo);
}
}
protected List<ChatPlugin> getPluginList(ChatParseContext chatParseContext) {
return PluginManager.getPluginAgentCanSupport(chatParseContext);
protected List<ChatPlugin> getPluginList(ParseContext parseContext) {
return PluginManager.getPluginAgentCanSupport(parseContext);
}
protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin,
ChatParseContext chatParseContext, double distance) {
List<SchemaElementMatch> schemaElementMatches = chatParseContext.getMapInfo().getMatchedElements(dataSetId);
QueryFilters queryFilters = chatParseContext.getQueryFilters();
ParseContext parseContext, double distance) {
List<SchemaElementMatch> schemaElementMatches = parseContext.getMapInfo().getMatchedElements(dataSetId);
QueryFilters queryFilters = parseContext.getQueryFilters();
if (schemaElementMatches == null) {
schemaElementMatches = Lists.newArrayList();
}
@@ -80,7 +80,7 @@ public abstract class PluginRecognizer {
pluginParseResult.setPlugin(plugin);
pluginParseResult.setQueryFilters(queryFilters);
pluginParseResult.setDistance(distance);
pluginParseResult.setQueryText(chatParseContext.getQueryText());
pluginParseResult.setQueryText(parseContext.getQueryText());
properties.put(Constants.CONTEXT, pluginParseResult);
properties.put("type", "plugin");
properties.put("name", plugin.getName());

View File

@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
import com.tencent.supersonic.chat.server.plugin.PluginManager;
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
@@ -26,25 +26,25 @@ import java.util.stream.Collectors;
@Slf4j
public class EmbeddingRecallRecognizer extends PluginRecognizer {
public boolean checkPreCondition(ChatParseContext chatParseContext) {
List<ChatPlugin> plugins = getPluginList(chatParseContext);
public boolean checkPreCondition(ParseContext parseContext) {
List<ChatPlugin> plugins = getPluginList(parseContext);
return !CollectionUtils.isEmpty(plugins);
}
public PluginRecallResult recallPlugin(ChatParseContext chatParseContext) {
String text = chatParseContext.getQueryText();
public PluginRecallResult recallPlugin(ParseContext parseContext) {
String text = parseContext.getQueryText();
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
return null;
}
List<ChatPlugin> plugins = getPluginList(chatParseContext);
List<ChatPlugin> plugins = getPluginList(parseContext);
Map<Long, ChatPlugin> pluginMap = plugins.stream().collect(Collectors.toMap(ChatPlugin::getId, p -> p));
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
ChatPlugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
if (plugin == null) {
continue;
}
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, chatParseContext);
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, parseContext);
log.info("embedding plugin resolve: {}", pair);
if (pair.getLeft()) {
Set<Long> dataSetList = pair.getRight();
@@ -53,7 +53,7 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer {
}
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
double distance = embeddingRetrieval.getDistance();
double score = chatParseContext.getQueryText().length() * (1 - distance);
double score = parseContext.getQueryText().length() * (1 - distance);
return PluginRecallResult.builder()
.plugin(plugin).dataSetIds(dataSetList).score(score).distance(distance).build();
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic.headless.chat;
package com.tencent.supersonic.chat.server.pojo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
@@ -6,7 +6,6 @@ import lombok.Data;
@Data
public class ChatContext {
private Integer chatId;
private String queryText;
private SemanticParseInfo parseInfo = new SemanticParseInfo();

View File

@@ -1,17 +1,17 @@
package com.tencent.supersonic.chat.server.pojo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import lombok.Data;
@Data
public class ChatExecuteContext {
public class ExecuteContext {
private User user;
private Integer agentId;
private Long queryId;
private Integer chatId;
private int parseId;
private String queryText;
private Agent agent;
private Integer chatId;
private Long queryId;
private boolean saveAnswer;
private SemanticParseInfo parseInfo;
}

View File

@@ -7,14 +7,14 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import lombok.Data;
@Data
public class ChatParseContext {
private String queryText;
private Integer chatId;
private Agent agent;
public class ParseContext {
private User user;
private String queryText;
private Agent agent;
private Integer chatId;
private QueryFilters queryFilters;
private boolean saveAnswer = true;
private SchemaMapInfo mapInfo = new SchemaMapInfo();
private SchemaMapInfo mapInfo;
public boolean enableNL2SQL() {
if (agent == null) {

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.server.processor.execute;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
@@ -28,8 +28,8 @@ public class DimensionRecommendProcessor implements ExecuteResultProcessor {
private static final int recommend_dimension_size = 5;
@Override
public void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult) {
SemanticParseInfo semanticParseInfo = chatExecuteContext.getParseInfo();
public void process(ExecuteContext executeContext, QueryResult queryResult) {
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
if (!QueryType.METRIC.equals(semanticParseInfo.getQueryType())
|| CollectionUtils.isEmpty(semanticParseInfo.getMetrics())) {
return;

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.server.processor.execute;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
@@ -9,6 +9,6 @@ import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
*/
public interface ExecuteResultProcessor extends ResultProcessor {
void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult);
void process(ExecuteContext executeContext, QueryResult queryResult);
}

View File

@@ -11,7 +11,7 @@ import static com.tencent.supersonic.common.pojo.Constants.TIME_FORMAT;
import static com.tencent.supersonic.common.pojo.Constants.WEEK;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.pojo.QueryColumn;
@@ -60,15 +60,15 @@ import org.springframework.util.CollectionUtils;
public class MetricRatioProcessor implements ExecuteResultProcessor {
@Override
public void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult) {
SemanticParseInfo semanticParseInfo = chatExecuteContext.getParseInfo();
public void process(ExecuteContext executeContext, QueryResult queryResult) {
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
AggregatorConfig aggregatorConfig = ContextUtils.getBean(AggregatorConfig.class);
if (CollectionUtils.isEmpty(semanticParseInfo.getMetrics())
|| !aggregatorConfig.getEnableRatio()
|| !QueryType.METRIC.equals(semanticParseInfo.getQueryType())) {
return;
}
AggregateInfo aggregateInfo = getAggregateInfo(chatExecuteContext.getUser(),
AggregateInfo aggregateInfo = getAggregateInfo(executeContext.getUser(),
semanticParseInfo, queryResult);
queryResult.setAggregateInfo(aggregateInfo);
}

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.server.processor.execute;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.common.pojo.enums.QueryType;
@@ -34,8 +34,8 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
private static final int METRIC_RECOMMEND_SIZE = 5;
@Override
public void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult) {
fillSimilarMetric(chatExecuteContext.getParseInfo());
public void process(ExecuteContext executeContext, QueryResult queryResult) {
fillSimilarMetric(executeContext.getParseInfo());
}
private void fillSimilarMetric(SemanticParseInfo parseInfo) {

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
@@ -19,7 +19,7 @@ import java.util.List;
public class EntityInfoProcessor implements ParseResultProcessor {
@Override
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
public void process(ParseContext parseContext, ParseResp parseResp) {
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
if (CollectionUtils.isEmpty(selectedParses)) {
return;
@@ -33,7 +33,7 @@ public class EntityInfoProcessor implements ParseResultProcessor {
//1. set entity info
SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class);
DataSetSchema dataSetSchema = semanticService.getDataSetSchema(parseInfo.getDataSetId());
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, chatParseContext.getUser());
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, parseContext.getUser());
if (QueryManager.isTagQuery(queryMode)
|| QueryManager.isMetricQuery(queryMode)) {
parseInfo.setEntityInfo(entityInfo);

View File

@@ -1,10 +1,10 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
public interface ParseResultProcessor {
void process(ChatParseContext chatParseContext, ParseResp parseResp);
void process(ParseContext parseContext, ParseResp parseResp);
}

View File

@@ -5,7 +5,7 @@ import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.ExemplarService;
@@ -25,15 +25,15 @@ import java.util.stream.Collectors;
public class QueryRecommendProcessor implements ParseResultProcessor {
@Override
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
CompletableFuture.runAsync(() -> doProcess(parseResp, chatParseContext));
public void process(ParseContext parseContext, ParseResp parseResp) {
CompletableFuture.runAsync(() -> doProcess(parseResp, parseContext));
}
@SneakyThrows
private void doProcess(ParseResp parseResp, ChatParseContext chatParseContext) {
private void doProcess(ParseResp parseResp, ParseContext parseContext) {
Long queryId = parseResp.getQueryId();
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(chatParseContext.getQueryText(),
chatParseContext.getAgent().getId());
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(parseContext.getQueryText(),
parseContext.getAgent().getId());
ChatQueryDO chatQueryDO = getChatQuery(queryId);
chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries));
updateChatQuery(chatQueryDO);

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import lombok.extern.slf4j.Slf4j;
@@ -12,7 +12,7 @@ import lombok.extern.slf4j.Slf4j;
public class TimeCostProcessor implements ParseResultProcessor {
@Override
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
public void process(ParseContext parseContext, ParseResp parseResp) {
long parseStartTime = parseResp.getParseTimeCost().getParseStartTime();
parseResp.getParseTimeCost().setParseTime(
System.currentTimeMillis() - parseStartTime - parseResp.getParseTimeCost().getSqlTime());

View File

@@ -6,11 +6,10 @@ import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.BeanUtils;
@@ -32,20 +31,20 @@ import javax.validation.Valid;
public class ChatQueryController {
@Autowired
private ChatService chatService;
private ChatQueryService chatQueryService;
@PostMapping("search")
public Object search(@RequestBody ChatParseReq chatParseReq, HttpServletRequest request,
HttpServletResponse response) {
chatParseReq.setUser(UserHolder.findUser(request, response));
return chatService.search(chatParseReq);
return chatQueryService.search(chatParseReq);
}
@PostMapping("parse")
public Object parse(@RequestBody ChatParseReq chatParseReq,
HttpServletRequest request, HttpServletResponse response) throws Exception {
chatParseReq.setUser(UserHolder.findUser(request, response));
return chatService.performParsing(chatParseReq);
return chatQueryService.performParsing(chatParseReq);
}
@PostMapping("execute")
@@ -53,7 +52,7 @@ public class ChatQueryController {
HttpServletRequest request, HttpServletResponse response)
throws Exception {
chatExecuteReq.setUser(UserHolder.findUser(request, response));
return chatService.performExecution(chatExecuteReq);
return chatQueryService.performExecution(chatExecuteReq);
}
@PostMapping("/")
@@ -62,7 +61,7 @@ public class ChatQueryController {
throws Exception {
User user = UserHolder.findUser(request, response);
chatParseReq.setUser(user);
ParseResp parseResp = chatService.performParsing(chatParseReq);
ParseResp parseResp = chatQueryService.performParsing(chatParseReq);
if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) {
throw new InvalidArgumentException("parser error,no selectedParses");
@@ -72,27 +71,20 @@ public class ChatQueryController {
BeanUtils.copyProperties(chatParseReq, chatExecuteReq);
chatExecuteReq.setQueryId(parseResp.getQueryId());
chatExecuteReq.setParseId(semanticParseInfo.getId());
return chatService.performExecution(chatExecuteReq);
}
@PostMapping("queryContext")
public Object queryContext(@RequestBody QueryNLReq queryCtx,
HttpServletRequest request, HttpServletResponse response) {
queryCtx.setUser(UserHolder.findUser(request, response));
return chatService.queryContext(queryCtx.getChatId());
return chatQueryService.performExecution(chatExecuteReq);
}
@PostMapping("queryData")
public Object queryData(@RequestBody ChatQueryDataReq chatQueryDataReq,
HttpServletRequest request, HttpServletResponse response) throws Exception {
chatQueryDataReq.setUser(UserHolder.findUser(request, response));
return chatService.queryData(chatQueryDataReq, UserHolder.findUser(request, response));
return chatQueryService.queryData(chatQueryDataReq, UserHolder.findUser(request, response));
}
@PostMapping("queryDimensionValue")
public Object queryDimensionValue(@RequestBody @Valid DimensionValueReq dimensionValueReq,
HttpServletRequest request, HttpServletResponse response) throws Exception {
return chatService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response));
return chatQueryService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response));
}
}

View File

@@ -0,0 +1,11 @@
package com.tencent.supersonic.chat.server.service;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
public interface ChatContextService {
ChatContext getOrCreateContext(Integer chatId);
void updateContext(ChatContext chatCtx);
}

View File

@@ -4,7 +4,6 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
@@ -12,7 +11,7 @@ import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
import java.util.List;
public interface ChatService {
public interface ChatQueryService {
List<SearchResult> search(ChatParseReq chatParseReq);
@@ -24,8 +23,6 @@ public interface ChatService {
Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception;
SemanticParseInfo queryContext(Integer chatId);
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
}

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.server.service;
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
import java.util.List;

View File

@@ -9,7 +9,7 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
import com.tencent.supersonic.common.config.ChatModelConfig;
@@ -36,7 +36,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
private MemoryService memoryService;
@Autowired
private ChatService chatService;
private ChatQueryService chatQueryService;
private ExecutorService executorService = Executors.newFixedThreadPool(1);
@@ -103,7 +103,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
continue;
}
try {
chatService.parseAndExecute(-1, agent.getId(), example);
chatQueryService.parseAndExecute(-1, agent.getId(), example);
} catch (Exception e) {
log.warn("agent:{} example execute failed:{}", agent.getName(), example);
}

View File

@@ -0,0 +1,30 @@
package com.tencent.supersonic.chat.server.service.impl;
import com.tencent.supersonic.chat.server.service.ChatContextService;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
import com.tencent.supersonic.chat.server.persistence.repository.ChatContextRepository;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
@Slf4j
@Service
public class ChatContextServiceImpl implements ChatContextService {
private ChatContextRepository chatContextRepository;
public ChatContextServiceImpl(ChatContextRepository chatContextRepository) {
this.chatContextRepository = chatContextRepository;
}
@Override
public ChatContext getOrCreateContext(Integer chatId) {
return chatContextRepository.getOrCreateContext(chatId);
}
@Override
public void updateContext(ChatContext chatCtx) {
log.debug("save ChatContext {}", chatCtx);
chatContextRepository.updateContext(chatCtx);
}
}

View File

@@ -6,15 +6,16 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.executor.ChatExecutor;
import com.tencent.supersonic.chat.server.parser.ChatParser;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor;
import com.tencent.supersonic.chat.server.parser.ChatQueryParser;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatContextService;
import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.chat.server.util.ComponentFactory;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.util.BeanMapper;
@@ -26,7 +27,7 @@ import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import com.tencent.supersonic.headless.server.facade.service.RetrieveService;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import lombok.extern.slf4j.Slf4j;
@@ -39,49 +40,51 @@ import java.util.List;
@Slf4j
@Service
public class ChatServiceImpl implements ChatService {
public class ChatQueryServiceImpl implements ChatQueryService {
@Autowired
private ChatManageService chatManageService;
@Autowired
private ChatQueryService chatQueryService;
private ChatLayerService chatLayerService;
@Autowired
private RetrieveService retrieveService;
@Autowired
private AgentService agentService;
@Autowired
private SemanticLayerService semanticLayerService;
@Autowired
private ChatContextService chatContextService;
private List<ChatParser> chatParsers = ComponentFactory.getChatParsers();
private List<ChatExecutor> chatExecutors = ComponentFactory.getChatExecutors();
private List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
private List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors();
private List<ParseResultProcessor> parseResultProcessors = ComponentFactory.getParseProcessors();
private List<ExecuteResultProcessor> executeResultProcessors = ComponentFactory.getExecuteProcessors();
@Override
public List<SearchResult> search(ChatParseReq chatParseReq) {
ChatParseContext chatParseContext = buildParseContext(chatParseReq);
Agent agent = chatParseContext.getAgent();
ParseContext parseContext = buildParseContext(chatParseReq);
Agent agent = parseContext.getAgent();
if (!agent.enableSearch()) {
return Lists.newArrayList();
}
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
return retrieveService.retrieve(queryNLReq);
}
@Override
public ParseResp performParsing(ChatParseReq chatParseReq) {
ParseResp parseResp = new ParseResp(chatParseReq.getChatId(), chatParseReq.getQueryText());
ParseResp parseResp = new ParseResp(chatParseReq.getQueryText());
chatManageService.createChatQuery(chatParseReq, parseResp);
ChatParseContext chatParseContext = buildParseContext(chatParseReq);
supplyMapInfo(chatParseContext);
for (ChatParser chatParser : chatParsers) {
chatParser.parse(chatParseContext, parseResp);
ParseContext parseContext = buildParseContext(chatParseReq);
supplyMapInfo(parseContext);
for (ChatQueryParser chatQueryParser : chatQueryParsers) {
chatQueryParser.parse(parseContext, parseResp);
}
for (ParseResultProcessor processor : parseResultProcessors) {
processor.process(chatParseContext, parseResp);
processor.process(parseContext, parseResp);
}
chatParseReq.setQueryText(chatParseContext.getQueryText());
parseResp.setQueryText(chatParseContext.getQueryText());
chatParseReq.setQueryText(parseContext.getQueryText());
parseResp.setQueryText(parseContext.getQueryText());
chatManageService.batchAddParse(chatParseReq, parseResp);
chatManageService.updateParseCostTime(parseResp);
return parseResp;
@@ -90,9 +93,9 @@ public class ChatServiceImpl implements ChatService {
@Override
public QueryResult performExecution(ChatExecuteReq chatExecuteReq) {
QueryResult queryResult = new QueryResult();
ChatExecuteContext chatExecuteContext = buildExecuteContext(chatExecuteReq);
for (ChatExecutor chatExecutor : chatExecutors) {
queryResult = chatExecutor.execute(chatExecuteContext);
ExecuteContext executeContext = buildExecuteContext(chatExecuteReq);
for (ChatQueryExecutor chatQueryExecutor : chatQueryExecutors) {
queryResult = chatQueryExecutor.execute(executeContext);
if (queryResult != null) {
break;
}
@@ -100,7 +103,7 @@ public class ChatServiceImpl implements ChatService {
if (queryResult != null) {
for (ExecuteResultProcessor processor : executeResultProcessors) {
processor.process(chatExecuteContext, queryResult);
processor.process(executeContext, queryResult);
}
saveQueryResult(chatExecuteReq, queryResult);
}
@@ -125,34 +128,36 @@ public class ChatServiceImpl implements ChatService {
executeReq.setQueryId(parseResp.getQueryId());
executeReq.setParseId(parseResp.getSelectedParses().get(0).getId());
executeReq.setQueryText(queryText);
executeReq.setChatId(parseResp.getChatId());
executeReq.setChatId(chatId);
executeReq.setUser(User.getFakeUser());
executeReq.setAgentId(agentId);
executeReq.setSaveAnswer(true);
return performExecution(executeReq);
}
private ChatParseContext buildParseContext(ChatParseReq chatParseReq) {
ChatParseContext chatParseContext = new ChatParseContext();
BeanMapper.mapper(chatParseReq, chatParseContext);
private ParseContext buildParseContext(ChatParseReq chatParseReq) {
ParseContext parseContext = new ParseContext();
BeanMapper.mapper(chatParseReq, parseContext);
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
chatParseContext.setAgent(agent);
return chatParseContext;
parseContext.setAgent(agent);
return parseContext;
}
private void supplyMapInfo(ChatParseContext chatParseContext) {
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
MapResp mapResp = chatQueryService.performMapping(queryNLReq);
chatParseContext.setMapInfo(mapResp.getMapInfo());
private void supplyMapInfo(ParseContext parseContext) {
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
MapResp mapResp = chatLayerService.performMapping(queryNLReq);
parseContext.setMapInfo(mapResp.getMapInfo());
}
private ChatExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
ChatExecuteContext chatExecuteContext = new ChatExecuteContext();
BeanMapper.mapper(chatExecuteReq, chatExecuteContext);
private ExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
ExecuteContext executeContext = new ExecuteContext();
BeanMapper.mapper(chatExecuteReq, executeContext);
SemanticParseInfo parseInfo = chatManageService.getParseInfo(
chatExecuteReq.getQueryId(), chatExecuteReq.getParseId());
chatExecuteContext.setParseInfo(parseInfo);
return chatExecuteContext;
Agent agent = agentService.getAgent(chatExecuteReq.getAgentId());
executeContext.setAgent(agent);
executeContext.setParseInfo(parseInfo);
return executeContext;
}
@Override
@@ -163,12 +168,7 @@ public class ChatServiceImpl implements ChatService {
QueryDataReq queryData = new QueryDataReq();
BeanMapper.mapper(chatQueryDataReq, queryData);
queryData.setParseInfo(parseInfo);
return chatQueryService.executeDirectQuery(queryData, user);
}
@Override
public SemanticParseInfo queryContext(Integer chatId) {
return chatQueryService.queryContext(chatId);
return chatLayerService.executeDirectQuery(queryData, user);
}
@Override
@@ -176,7 +176,7 @@ public class ChatServiceImpl implements ChatService {
Integer agentId = dimensionValueReq.getAgentId();
Agent agent = agentService.getAgent(agentId);
dimensionValueReq.setDataSetIds(agent.getDataSetIds());
return chatQueryService.queryDimensionValue(dimensionValueReq, user);
return chatLayerService.queryDimensionValue(dimensionValueReq, user);
}
public void saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult) {

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.chat.server.service.impl;
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.server.service.StatisticsService;
import com.tencent.supersonic.headless.server.persistence.mapper.StatisticsMapper;
import com.tencent.supersonic.chat.server.persistence.mapper.StatisticsMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Async;

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.chat.server.executor.ChatExecutor;
import com.tencent.supersonic.chat.server.parser.ChatParser;
import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor;
import com.tencent.supersonic.chat.server.parser.ChatQueryParser;
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
@@ -16,8 +16,8 @@ import java.util.List;
public class ComponentFactory {
private static List<ParseResultProcessor> parseProcessors = new ArrayList<>();
private static List<ExecuteResultProcessor> executeProcessors = new ArrayList<>();
private static List<ChatParser> chatParsers = new ArrayList<>();
private static List<ChatExecutor> chatExecutors = new ArrayList<>();
private static List<ChatQueryParser> chatQueryParsers = new ArrayList<>();
private static List<ChatQueryExecutor> chatQueryExecutors = new ArrayList<>();
private static List<PluginRecognizer> pluginRecognizers = new ArrayList<>();
public static List<ParseResultProcessor> getParseProcessors() {
@@ -30,14 +30,14 @@ public class ComponentFactory {
? init(ExecuteResultProcessor.class, executeProcessors) : executeProcessors;
}
public static List<ChatParser> getChatParsers() {
return CollectionUtils.isEmpty(chatParsers)
? init(ChatParser.class, chatParsers) : chatParsers;
public static List<ChatQueryParser> getChatParsers() {
return CollectionUtils.isEmpty(chatQueryParsers)
? init(ChatQueryParser.class, chatQueryParsers) : chatQueryParsers;
}
public static List<ChatExecutor> getChatExecutors() {
return CollectionUtils.isEmpty(chatExecutors)
? init(ChatExecutor.class, chatExecutors) : chatExecutors;
public static List<ChatQueryExecutor> getChatExecutors() {
return CollectionUtils.isEmpty(chatQueryExecutors)
? init(ChatQueryExecutor.class, chatQueryExecutors) : chatQueryExecutors;
}
public static List<PluginRecognizer> getPluginRecognizers() {

View File

@@ -1,20 +1,25 @@
package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
import org.apache.commons.collections.MapUtils;
import java.util.Objects;
public class QueryReqConverter {
public static QueryNLReq buildText2SqlQueryReq(ChatParseContext chatParseContext) {
public static QueryNLReq buildText2SqlQueryReq(ParseContext parseContext) {
return buildText2SqlQueryReq(parseContext, null);
}
public static QueryNLReq buildText2SqlQueryReq(ParseContext parseContext, ChatContext chatCtx) {
QueryNLReq queryNLReq = new QueryNLReq();
BeanMapper.mapper(chatParseContext, queryNLReq);
Agent agent = chatParseContext.getAgent();
BeanMapper.mapper(parseContext, queryNLReq);
Agent agent = parseContext.getAgent();
if (agent == null) {
return queryNLReq;
}
@@ -39,6 +44,9 @@ public class QueryReqConverter {
}
queryNLReq.setModelConfig(agent.getModelConfig());
queryNLReq.setPromptConfig(agent.getPromptConfig());
if (chatCtx != null) {
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
}
return queryNLReq;
}

View File

@@ -3,10 +3,10 @@
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.tencent.supersonic.headless.server.persistence.mapper.ChatContextMapper">
<mapper namespace="com.tencent.supersonic.chat.server.persistence.mapper.ChatContextMapper">
<resultMap id="ChatContextDO"
type="com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO">
type="com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO">
<id column="chat_id" property="chatId"/>
<result column="modified_at" property="modifiedAt"/>
<result column="user" property="user"/>
@@ -20,7 +20,7 @@
from s2_chat_context where chat_id=#{chatId} limit 1
</select>
<insert id="addContext" parameterType="com.tencent.supersonic.headless.server.persistence.dataobject.ChatContextDO" >
<insert id="addContext" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO" >
insert into s2_chat_context (chat_id,user,query_text,semantic_parse) values (#{chatId}, #{user},#{queryText}, #{semanticParse})
</insert>
<update id="updateContext">

View File

@@ -3,9 +3,9 @@
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.tencent.supersonic.headless.server.persistence.mapper.StatisticsMapper">
<mapper namespace="com.tencent.supersonic.chat.server.persistence.mapper.StatisticsMapper">
<resultMap id="Statistics" type="com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO">
<resultMap id="Statistics" type="com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO">
<id column="question_id" property="questionId"/>
<result column="chat_id" property="chatId"/>
<result column="user_name" property="userName"/>
@@ -16,7 +16,7 @@
<result column="create_time" property="createTime"/>
</resultMap>
<insert id="batchSaveStatistics" parameterType="com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO">
<insert id="batchSaveStatistics" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO">
insert into s2_chat_statistics
(question_id,chat_id, user_name, query_text, interface_name,cost,type ,create_time)
values

View File

@@ -11,7 +11,6 @@ import lombok.Data;
public class ExecuteQueryReq {
private User user;
private Long queryId;
private Integer chatId;
private String queryText;
private SemanticParseInfo parseInfo;
private boolean saveAnswer;

View File

@@ -9,6 +9,7 @@ import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import lombok.Data;
@@ -18,7 +19,6 @@ import java.util.Set;
@Data
public class QueryNLReq {
private String queryText;
private Integer chatId;
private Set<Long> dataSetIds = Sets.newHashSet();
private User user;
private QueryFilters queryFilters;
@@ -30,4 +30,5 @@ public class QueryNLReq {
private ChatModelConfig modelConfig;
private PromptConfig promptConfig;
private List<SqlExemplar> dynamicExemplars = Lists.newArrayList();
private SemanticParseInfo contextParseInfo;
}

View File

@@ -10,7 +10,6 @@ import java.util.stream.Collectors;
@Data
public class ParseResp {
private Integer chatId;
private String queryText;
private Long queryId;
private ParseState state = ParseState.PENDING;
@@ -24,8 +23,7 @@ public class ParseResp {
FAILED
}
public ParseResp(Integer chatId, String queryText) {
this.chatId = chatId;
public ParseResp(String queryText) {
this.queryText = queryText;
parseTimeCost.setParseStartTime(System.currentTimeMillis());
}

View File

@@ -8,6 +8,7 @@ import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
@@ -35,7 +36,6 @@ import java.util.stream.Collectors;
public class ChatQueryContext {
private String queryText;
private Integer chatId;
private Set<Long> dataSetIds;
private Map<Long, List<Long>> modelIdToDataSetIds;
private User user;
@@ -54,6 +54,7 @@ public class ChatQueryContext {
private ChatModelConfig modelConfig;
private PromptConfig promptConfig;
private List<SqlExemplar> dynamicExemplars;
private SemanticParseInfo contextParseInfo;
public List<SemanticQuery> getCandidateQueries() {
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);

View File

@@ -11,7 +11,6 @@ import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@@ -29,7 +28,7 @@ import java.util.stream.Collectors;
public class QueryTypeParser implements SemanticParser {
@Override
public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) {
public void parse(ChatQueryContext chatQueryContext) {
List<SemanticQuery> candidateQueries = chatQueryContext.getCandidateQueries();
User user = chatQueryContext.getUser();

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.headless.chat.parser;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
/**
@@ -10,5 +9,5 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
*/
public interface SemanticParser {
void parse(ChatQueryContext chatQueryContext, ChatContext chatContext);
void parse(ChatQueryContext chatQueryContext);
}

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
@@ -23,7 +22,7 @@ import org.apache.commons.collections.MapUtils;
public class LLMSqlParser implements SemanticParser {
@Override
public void parse(ChatQueryContext queryCtx, ChatContext chatCtx) {
public void parse(ChatQueryContext queryCtx) {
try {
LLMRequestService requestService = ContextUtils.getBean(LLMRequestService.class);
//1.determine whether to skip this parser.

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.chat.parser.rule;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
@@ -41,7 +40,7 @@ public class AggregateTypeParser implements SemanticParser {
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (k1, k2) -> k2));
@Override
public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) {
public void parse(ChatQueryContext chatQueryContext) {
String queryText = chatQueryContext.getQueryText();
AggregateConf aggregateConf = resolveAggregateConf(queryText);

View File

@@ -7,7 +7,6 @@ import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricSemanticQuery;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricIdQuery;
@@ -43,11 +42,11 @@ public class ContextInheritParser implements SemanticParser {
).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
@Override
public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) {
public void parse(ChatQueryContext chatQueryContext) {
if (!shouldInherit(chatQueryContext)) {
return;
}
Long dataSetId = getMatchedDataSet(chatQueryContext, chatContext);
Long dataSetId = getMatchedDataSet(chatQueryContext);
if (dataSetId == null) {
return;
}
@@ -55,10 +54,11 @@ public class ContextInheritParser implements SemanticParser {
List<SchemaElementMatch> elementMatches = chatQueryContext.getMapInfo().getMatchedElements(dataSetId);
List<SchemaElementMatch> matchesToInherit = new ArrayList<>();
for (SchemaElementMatch match : chatContext.getParseInfo().getElementMatches()) {
for (SchemaElementMatch match : chatQueryContext.getContextParseInfo().getElementMatches()) {
SchemaElementType matchType = match.getElement().getType();
// mutual exclusive element types should not be inherited
RuleSemanticQuery ruleQuery = QueryManager.getRuleQuery(chatContext.getParseInfo().getQueryMode());
RuleSemanticQuery ruleQuery = QueryManager.getRuleQuery(
chatQueryContext.getContextParseInfo().getQueryMode());
if (!containsTypes(elementMatches, matchType, ruleQuery)) {
match.setInherited(true);
matchesToInherit.add(match);
@@ -68,7 +68,7 @@ public class ContextInheritParser implements SemanticParser {
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext);
for (RuleSemanticQuery query : queries) {
query.fillParseInfo(chatQueryContext, chatContext);
query.fillParseInfo(chatQueryContext);
if (existSameQuery(query.getParseInfo().getDataSetId(), query.getQueryMode(), chatQueryContext)) {
continue;
}
@@ -108,8 +108,8 @@ public class ContextInheritParser implements SemanticParser {
return metricModelQueries.size() == chatQueryContext.getCandidateQueries().size();
}
protected Long getMatchedDataSet(ChatQueryContext chatQueryContext, ChatContext chatContext) {
Long dataSetId = chatContext.getParseInfo().getDataSetId();
protected Long getMatchedDataSet(ChatQueryContext chatQueryContext) {
Long dataSetId = chatQueryContext.getContextParseInfo().getDataSetId();
if (dataSetId == null) {
return null;
}

View File

@@ -5,7 +5,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.chat.ChatContext;
import lombok.extern.slf4j.Slf4j;
import java.util.Arrays;
import java.util.List;
@@ -24,7 +23,7 @@ public class RuleSqlParser implements SemanticParser {
);
@Override
public void parse(ChatQueryContext chatQueryContext, ChatContext chatContext) {
public void parse(ChatQueryContext chatQueryContext) {
if (!chatQueryContext.getText2SQLType().enableRule()) {
return;
}
@@ -34,11 +33,11 @@ public class RuleSqlParser implements SemanticParser {
List<SchemaElementMatch> elementMatches = mapInfo.getMatchedElements(dataSetId);
List<RuleSemanticQuery> queries = RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext);
for (RuleSemanticQuery query : queries) {
query.fillParseInfo(chatQueryContext, chatContext);
query.fillParseInfo(chatQueryContext);
chatQueryContext.getCandidateQueries().add(query);
}
}
auxiliaryParsers.stream().forEach(p -> p.parse(chatQueryContext, chatContext));
auxiliaryParsers.stream().forEach(p -> p.parse(chatQueryContext));
}
}

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.headless.chat.parser.rule;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.parser.SemanticParser;
import com.tencent.supersonic.headless.chat.query.QueryManager;
@@ -42,7 +41,7 @@ public class TimeRangeParser implements SemanticParser {
private static final DateFormat DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd");
@Override
public void parse(ChatQueryContext queryContext, ChatContext chatContext) {
public void parse(ChatQueryContext queryContext) {
String queryText = queryContext.getQueryText();
DateConf dateConf = parseRecent(queryText);
if (dateConf == null) {
@@ -59,14 +58,14 @@ public class TimeRangeParser implements SemanticParser {
query.getParseInfo().setScore(query.getParseInfo().getScore()
+ dateConf.getDetectWord().length());
}
} else if (QueryManager.containsRuleQuery(chatContext.getParseInfo().getQueryMode())) {
} else if (QueryManager.containsRuleQuery(queryContext.getContextParseInfo().getQueryMode())) {
RuleSemanticQuery semanticQuery = QueryManager.createRuleQuery(
chatContext.getParseInfo().getQueryMode());
queryContext.getContextParseInfo().getQueryMode());
// inherit parse info from context
chatContext.getParseInfo().setDateInfo(dateConf);
chatContext.getParseInfo().setScore(chatContext.getParseInfo().getScore()
queryContext.getContextParseInfo().setDateInfo(dateConf);
queryContext.getContextParseInfo().setScore(queryContext.getContextParseInfo().getScore()
+ dateConf.getDetectWord().length());
semanticQuery.setParseInfo(chatContext.getParseInfo());
semanticQuery.setParseInfo(queryContext.getContextParseInfo());
queryContext.getCandidateQueries().add(semanticQuery);
}
}

View File

@@ -16,7 +16,6 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.utils.QueryReqBuilder;
import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery;
import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.ChatContext;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
@@ -49,13 +48,13 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery {
initS2SqlByStruct(semanticSchema);
}
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
public void fillParseInfo(ChatQueryContext chatQueryContext) {
parseInfo.setQueryMode(getQueryMode());
SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema();
fillSchemaElement(parseInfo, semanticSchema);
fillScore(parseInfo);
fillDateConf(parseInfo, chatContext.getParseInfo());
fillDateConf(parseInfo, chatQueryContext.getContextParseInfo());
}
private void fillDateConf(SemanticParseInfo queryParseInfo, SemanticParseInfo chatParseInfo) {

View File

@@ -7,7 +7,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import org.apache.commons.collections.CollectionUtils;
@@ -19,8 +18,8 @@ import java.util.stream.Collectors;
public abstract class DetailListQuery extends DetailSemanticQuery {
@Override
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
super.fillParseInfo(chatQueryContext, chatContext);
public void fillParseInfo(ChatQueryContext chatQueryContext) {
super.fillParseInfo(chatQueryContext);
this.addEntityDetailAndOrderByMetric(chatQueryContext, parseInfo);
}

View File

@@ -10,7 +10,6 @@ import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.rule.QueryMatchOption;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.chat.ChatContext;
import lombok.extern.slf4j.Slf4j;
import java.time.LocalDate;
@@ -35,8 +34,8 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery {
}
@Override
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
super.fillParseInfo(chatQueryContext, chatContext);
public void fillParseInfo(ChatQueryContext chatQueryContext) {
super.fillParseInfo(chatQueryContext);
parseInfo.setQueryType(QueryType.DETAIL);
parseInfo.setLimit(DETAIL_MAX_RESULTS);

View File

@@ -8,7 +8,6 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery;
import com.tencent.supersonic.headless.chat.ChatContext;
import lombok.extern.slf4j.Slf4j;
import java.time.LocalDate;
@@ -36,8 +35,8 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery {
}
@Override
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
super.fillParseInfo(chatQueryContext, chatContext);
public void fillParseInfo(ChatQueryContext chatQueryContext) {
super.fillParseInfo(chatQueryContext);
parseInfo.setLimit(METRIC_MAX_RESULTS);
if (parseInfo.getDateInfo() == null) {
DataSetSchema dataSetSchema =

View File

@@ -6,7 +6,6 @@ import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.ChatContext;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
@@ -50,8 +49,8 @@ public class MetricTopNQuery extends MetricSemanticQuery {
}
@Override
public void fillParseInfo(ChatQueryContext chatQueryContext, ChatContext chatContext) {
super.fillParseInfo(chatQueryContext, chatContext);
public void fillParseInfo(ChatQueryContext chatQueryContext) {
super.fillParseInfo(chatQueryContext);
parseInfo.setLimit(ORDERBY_MAX_RESULTS);
parseInfo.setScore(parseInfo.getScore() + 2.0);

View File

@@ -6,7 +6,7 @@ import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import com.tencent.supersonic.headless.server.facade.service.RetrieveService;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import lombok.extern.slf4j.Slf4j;
@@ -24,7 +24,7 @@ import javax.servlet.http.HttpServletResponse;
public class ChatQueryApiController {
@Autowired
private ChatQueryService chatQueryService;
private ChatLayerService chatLayerService;
@Autowired
private RetrieveService retrieveService;
@@ -45,7 +45,7 @@ public class ChatQueryApiController {
HttpServletRequest request,
HttpServletResponse response) {
queryNLReq.setUser(UserHolder.findUser(request, response));
return chatQueryService.performMapping(queryNLReq);
return chatLayerService.performMapping(queryNLReq);
}
@PostMapping("/chat/parse")
@@ -53,7 +53,7 @@ public class ChatQueryApiController {
HttpServletRequest request,
HttpServletResponse response) throws Exception {
queryNLReq.setUser(UserHolder.findUser(request, response));
return chatQueryService.performParsing(queryNLReq);
return chatLayerService.performParsing(queryNLReq);
}
@PostMapping("/chat")
@@ -61,7 +61,7 @@ public class ChatQueryApiController {
HttpServletRequest request,
HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response);
ParseResp parseResp = chatQueryService.performParsing(queryNLReq);
ParseResp parseResp = chatLayerService.performParsing(queryNLReq);
if (parseResp.getState().equals(ParseResp.ParseState.COMPLETED)) {
SemanticParseInfo parseInfo = parseResp.getSelectedParses().get(0);
QuerySqlReq sqlReq = new QuerySqlReq();

View File

@@ -3,7 +3,7 @@ package com.tencent.supersonic.headless.server.facade.rest;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.PostMapping;
@@ -20,14 +20,14 @@ import javax.servlet.http.HttpServletResponse;
public class MetaDiscoveryApiController {
@Autowired
private ChatQueryService chatQueryService;
private ChatLayerService chatLayerService;
@PostMapping("map")
public Object map(@RequestBody QueryMapReq queryMapReq,
HttpServletRequest request, HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response);
queryMapReq.setUser(user);
return chatQueryService.map(queryMapReq);
return chatLayerService.map(queryMapReq);
}
}

View File

@@ -7,7 +7,7 @@ import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QuerySqlsReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
@@ -33,7 +33,7 @@ public class SqlQueryApiController {
private SemanticLayerService semanticLayerService;
@Autowired
private ChatQueryService chatQueryService;
private ChatLayerService chatLayerService;
@PostMapping("/sql")
public Object queryBySql(@RequestBody QuerySqlReq querySqlReq,
@@ -42,7 +42,7 @@ public class SqlQueryApiController {
User user = UserHolder.findUser(request, response);
String sql = querySqlReq.getSql();
querySqlReq.setSql(StringUtil.replaceBackticks(sql));
chatQueryService.correct(querySqlReq, user);
chatLayerService.correct(querySqlReq, user);
return semanticLayerService.queryByReq(querySqlReq, user);
}
@@ -56,7 +56,7 @@ public class SqlQueryApiController {
QuerySqlReq querySqlReq = new QuerySqlReq();
BeanUtils.copyProperties(querySqlsReq, querySqlReq);
querySqlReq.setSql(StringUtil.replaceBackticks(sql));
chatQueryService.correct(querySqlReq, user);
chatLayerService.correct(querySqlReq, user);
return querySqlReq;
}).collect(Collectors.toList());
@@ -82,7 +82,7 @@ public class SqlQueryApiController {
QuerySqlReq querySqlReq = new QuerySqlReq();
BeanUtils.copyProperties(querySqlsReq, querySqlReq);
querySqlReq.setSql(StringUtil.replaceBackticks(sql));
chatQueryService.correct(querySqlReq, user);
chatLayerService.correct(querySqlReq, user);
return querySqlReq;
}).collect(Collectors.toList());
List<SemanticQueryResp> semanticQueryRespList = new ArrayList<>();
@@ -104,7 +104,7 @@ public class SqlQueryApiController {
User user = UserHolder.findUser(request, response);
String sql = querySqlReq.getSql();
querySqlReq.setSql(StringUtil.replaceBackticks(sql));
return chatQueryService.validate(querySqlReq, user);
return chatLayerService.validate(querySqlReq, user);
}
}

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.server.facade.service;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.SqlEvaluation;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryDataReq;
@@ -16,14 +15,12 @@ import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
/***dd
* SemanticLayerService for query and search
*/
public interface ChatQueryService {
public interface ChatLayerService {
MapResp performMapping(QueryNLReq queryNLReq);
ParseResp performParsing(QueryNLReq queryNLReq);
SemanticParseInfo queryContext(Integer chatId);
QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws Exception;
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;

View File

@@ -22,7 +22,6 @@ import com.tencent.supersonic.headless.chat.mapper.MatchText;
import com.tencent.supersonic.headless.chat.mapper.ModelWithSemanticType;
import com.tencent.supersonic.headless.chat.mapper.SearchMatchStrategy;
import com.tencent.supersonic.headless.server.facade.service.RetrieveService;
import com.tencent.supersonic.headless.server.web.service.ChatContextService;
import com.tencent.supersonic.headless.server.web.service.DataSetService;
import com.tencent.supersonic.headless.server.web.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
@@ -52,9 +51,6 @@ public class RetrieveServiceImpl implements RetrieveService {
@Autowired
private DataSetService dataSetService;
@Autowired
private ChatContextService chatContextService;
@Autowired
private SchemaService schemaService;
@@ -135,7 +131,7 @@ public class RetrieveServiceImpl implements RetrieveService {
List<Long> possibleDataSets = NatureHelper.selectPossibleDataSets(originals);
Long contextDataset = chatContextService.getContextDataset(queryCtx.getChatId());
Long contextDataset = queryCtx.getContextParseInfo().getDataSetId();
log.debug("possibleDataSets:{},dataSetInfoStat:{},contextDataset:{}",
possibleDataSets, dataSetInfoStat, contextDataset);

View File

@@ -44,7 +44,6 @@ import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.corrector.GrammarCorrector;
import com.tencent.supersonic.headless.chat.corrector.SchemaCorrector;
@@ -57,12 +56,11 @@ import com.tencent.supersonic.headless.chat.knowledge.helper.NatureHelper;
import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import com.tencent.supersonic.headless.server.utils.ChatWorkflowEngine;
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.utils.ComponentFactory;
import com.tencent.supersonic.headless.server.web.service.ChatContextService;
import com.tencent.supersonic.headless.server.web.service.DataSetService;
import com.tencent.supersonic.headless.server.web.service.SchemaService;
import lombok.extern.slf4j.Slf4j;
@@ -98,14 +96,12 @@ import java.util.stream.Collectors;
@Service
@Slf4j
public class ChatQueryServiceImpl implements ChatQueryService {
public class S2ChatLayerService implements ChatLayerService {
@Autowired
private SemanticLayerService semanticLayerService;
@Autowired
private SchemaService schemaService;
@Autowired
private ChatContextService chatContextService;
@Autowired
private KnowledgeBaseService knowledgeBaseService;
@Autowired
private DataSetService dataSetService;
@@ -141,14 +137,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
@Override
public ParseResp performParsing(QueryNLReq queryNLReq) {
ParseResp parseResult = new ParseResp(queryNLReq.getChatId(), queryNLReq.getQueryText());
// build queryContext and chatContext
ParseResp parseResult = new ParseResp(queryNLReq.getQueryText());
// build queryContext
ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq);
// in order to support multi-turn conversation, chat context is needed
ChatContext chatCtx = chatContextService.getOrCreateContext(queryNLReq.getChatId());
chatWorkflowEngine.execute(queryCtx, chatCtx, parseResult);
chatWorkflowEngine.execute(queryCtx, parseResult);
List<SemanticParseInfo> parseInfos = queryCtx.getCandidateQueries().stream()
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
@@ -173,12 +166,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return queryCtx;
}
@Override
public SemanticParseInfo queryContext(Integer chatId) {
ChatContext context = chatContextService.getOrCreateContext(chatId);
return context.getParseInfo();
}
//mainly used for executing after revising filters,for example:"fans_cnt>=100000"->"fans_cnt>500000",
//"style='流行'"->"style in ['流行','爱国']"
@Override

View File

@@ -14,7 +14,6 @@ import com.tencent.supersonic.headless.api.pojo.SemanticSchema;
import com.tencent.supersonic.headless.api.pojo.SqlInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.query.SemanticQuery;
import com.tencent.supersonic.headless.server.web.service.SchemaService;
@@ -40,7 +39,7 @@ import java.util.stream.Collectors;
public class ParseInfoProcessor implements ResultProcessor {
@Override
public void process(ParseResp parseResp, ChatQueryContext chatQueryContext, ChatContext chatContext) {
public void process(ParseResp parseResp, ChatQueryContext chatQueryContext) {
List<SemanticQuery> candidateQueries = chatQueryContext.getCandidateQueries();
if (CollectionUtils.isEmpty(candidateQueries)) {
return;

View File

@@ -1,7 +1,6 @@
package com.tencent.supersonic.headless.server.processor;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
/**
@@ -9,6 +8,6 @@ import com.tencent.supersonic.headless.chat.ChatQueryContext;
*/
public interface ResultProcessor {
void process(ParseResp parseResp, ChatQueryContext chatQueryContext, ChatContext chatContext);
void process(ParseResp parseResp, ChatQueryContext chatQueryContext);
}

View File

@@ -7,7 +7,6 @@ import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.corrector.SemanticCorrector;
import com.tencent.supersonic.headless.chat.mapper.SchemaMapper;
@@ -39,7 +38,7 @@ public class ChatWorkflowEngine {
private List<SemanticCorrector> semanticCorrectors = ComponentFactory.getSemanticCorrectors();
private List<ResultProcessor> resultProcessors = ComponentFactory.getResultProcessors();
public void execute(ChatQueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) {
public void execute(ChatQueryContext queryCtx, ParseResp parseResult) {
queryCtx.setChatWorkflowState(ChatWorkflowState.MAPPING);
while (queryCtx.getChatWorkflowState() != ChatWorkflowState.FINISHED) {
switch (queryCtx.getChatWorkflowState()) {
@@ -54,7 +53,7 @@ public class ChatWorkflowEngine {
}
break;
case PARSING:
performParsing(queryCtx, chatCtx);
performParsing(queryCtx);
if (queryCtx.getCandidateQueries().size() == 0) {
parseResult.setState(ParseResp.ParseState.FAILED);
parseResult.setErrorMsg("No semantic queries can be parsed out.");
@@ -75,7 +74,7 @@ public class ChatWorkflowEngine {
break;
case PROCESSING:
default:
performProcessing(queryCtx, chatCtx, parseResult);
performProcessing(queryCtx, parseResult);
if (parseResult.getState().equals(ParseResp.ParseState.PENDING)) {
parseResult.setState(ParseResp.ParseState.COMPLETED);
}
@@ -92,9 +91,9 @@ public class ChatWorkflowEngine {
}
}
private void performParsing(ChatQueryContext queryCtx, ChatContext chatCtx) {
private void performParsing(ChatQueryContext queryCtx) {
semanticParsers.forEach(parser -> {
parser.parse(queryCtx, chatCtx);
parser.parse(queryCtx);
log.debug("{} result:{}", parser.getClass().getSimpleName(), JsonUtil.toString(queryCtx));
});
}
@@ -116,9 +115,9 @@ public class ChatWorkflowEngine {
}
}
private void performProcessing(ChatQueryContext queryCtx, ChatContext chatCtx, ParseResp parseResult) {
private void performProcessing(ChatQueryContext queryCtx, ParseResp parseResult) {
resultProcessors.forEach(processor -> {
processor.process(parseResult, queryCtx, chatCtx);
processor.process(parseResult, queryCtx);
});
}

View File

@@ -1,18 +0,0 @@
package com.tencent.supersonic.headless.server.web.service;
import com.tencent.supersonic.headless.chat.ChatContext;
public interface ChatContextService {
/***
* get the model from context
* @param chatId
* @return
*/
Long getContextDataset(Integer chatId);
ChatContext getOrCreateContext(Integer chatId);
void updateContext(ChatContext chatCtx);
}

View File

@@ -1,50 +0,0 @@
package com.tencent.supersonic.headless.server.web.service.impl;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.headless.server.persistence.repository.ChatContextRepository;
import com.tencent.supersonic.headless.server.web.service.ChatContextService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.Objects;
@Slf4j
@Service
public class ChatContextServiceImpl implements ChatContextService {
private ChatContextRepository chatContextRepository;
public ChatContextServiceImpl(ChatContextRepository chatContextRepository) {
this.chatContextRepository = chatContextRepository;
}
@Override
public Long getContextDataset(Integer chatId) {
if (Objects.isNull(chatId)) {
return null;
}
ChatContext chatContext = getOrCreateContext(chatId);
if (Objects.isNull(chatContext)) {
return null;
}
SemanticParseInfo originalSemanticParse = chatContext.getParseInfo();
if (Objects.nonNull(originalSemanticParse) && Objects.nonNull(originalSemanticParse.getDataSetId())) {
return originalSemanticParse.getDataSetId();
}
return null;
}
@Override
public ChatContext getOrCreateContext(Integer chatId) {
return chatContextRepository.getOrCreateContext(chatId);
}
@Override
public void updateContext(ChatContext chatCtx) {
log.debug("save ChatContext {}", chatCtx);
chatContextRepository.updateContext(chatCtx);
}
}

View File

@@ -44,7 +44,7 @@ import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.TagItem;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import com.tencent.supersonic.headless.server.persistence.dataobject.CollectDO;
import com.tencent.supersonic.headless.server.persistence.dataobject.MetricDO;
import com.tencent.supersonic.headless.server.persistence.dataobject.MetricQueryDefaultConfigDO;
@@ -111,7 +111,7 @@ public class MetricServiceImpl extends ServiceImpl<MetricDOMapper, MetricDO>
private TagMetaService tagMetaService;
private ChatQueryService chatQueryService;
private ChatLayerService chatLayerService;
public MetricServiceImpl(MetricRepository metricRepository,
ModelService modelService,
@@ -121,7 +121,7 @@ public class MetricServiceImpl extends ServiceImpl<MetricDOMapper, MetricDO>
ApplicationEventPublisher eventPublisher,
DimensionService dimensionService,
TagMetaService tagMetaService,
@Lazy ChatQueryService chatQueryService) {
@Lazy ChatLayerService chatLayerService) {
this.metricRepository = metricRepository;
this.modelService = modelService;
this.aliasGenerateHelper = aliasGenerateHelper;
@@ -130,7 +130,7 @@ public class MetricServiceImpl extends ServiceImpl<MetricDOMapper, MetricDO>
this.dataSetService = dataSetService;
this.dimensionService = dimensionService;
this.tagMetaService = tagMetaService;
this.chatQueryService = chatQueryService;
this.chatLayerService = chatLayerService;
}
@Override
@@ -299,7 +299,7 @@ public class MetricServiceImpl extends ServiceImpl<MetricDOMapper, MetricDO>
queryMapReq.setQueryText(pageMetricReq.getKey());
queryMapReq.setUser(user);
queryMapReq.setMapModeEnum(MapModeEnum.LOOSE);
MapInfoResp mapMeta = chatQueryService.map(queryMapReq);
MapInfoResp mapMeta = chatLayerService.map(queryMapReq);
Map<String, DataSetMapInfo> dataSetMapInfoMap = mapMeta.getDataSetMapInfo();
if (CollectionUtils.isEmpty(dataSetMapInfoMap)) {
return metricRespPageInfo;

View File

@@ -16,7 +16,7 @@ import com.tencent.supersonic.headless.api.pojo.enums.MetricType;
import com.tencent.supersonic.headless.api.pojo.request.MetricReq;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import com.tencent.supersonic.headless.server.persistence.dataobject.MetricDO;
import com.tencent.supersonic.headless.server.persistence.repository.MetricRepository;
import com.tencent.supersonic.headless.server.utils.AliasGenerateHelper;
@@ -77,10 +77,10 @@ public class MetricServiceImplTest {
DataSetService dataSetService = Mockito.mock(DataSetServiceImpl.class);
DimensionService dimensionService = Mockito.mock(DimensionService.class);
TagMetaService tagMetaService = Mockito.mock(TagMetaService.class);
ChatQueryService chatQueryService = Mockito.mock(ChatQueryService.class);
ChatLayerService chatLayerService = Mockito.mock(ChatLayerService.class);
return new MetricServiceImpl(metricRepository, modelService, aliasGenerateHelper,
collectService, dataSetService, eventPublisher, dimensionService,
tagMetaService, chatQueryService);
tagMetaService, chatLayerService);
}
private MetricReq buildMetricReq() {

View File

@@ -5,7 +5,7 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authorization.service.AuthService;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.chat.server.service.PluginService;
import com.tencent.supersonic.common.service.SystemConfigService;
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
@@ -75,7 +75,7 @@ public abstract class S2BaseDemo implements CommandLineRunner {
@Autowired
protected TagObjectService tagObjectService;
@Autowired
protected ChatService chatService;
protected ChatQueryService chatQueryService;
@Autowired
protected ChatManageService chatManageService;
@Autowired

View File

@@ -137,11 +137,11 @@ public class S2VisitsDemo extends S2BaseDemo {
public void addSampleChats(Integer agentId) {
Long chatId = chatManageService.addChat(user, "样例对话1", agentId);
chatService.parseAndExecute(chatId.intValue(), agentId, "超音数 访问次数");
chatService.parseAndExecute(chatId.intValue(), agentId, "按部门统计");
chatService.parseAndExecute(chatId.intValue(), agentId, "查询近30天");
chatService.parseAndExecute(chatId.intValue(), agentId, "alice 停留时长");
chatService.parseAndExecute(chatId.intValue(), agentId, "访问次数最高的部门");
chatQueryService.parseAndExecute(chatId.intValue(), agentId, "超音数 访问次数");
chatQueryService.parseAndExecute(chatId.intValue(), agentId, "按部门统计");
chatQueryService.parseAndExecute(chatId.intValue(), agentId, "查询近30天");
chatQueryService.parseAndExecute(chatId.intValue(), agentId, "alice 停留时长");
chatQueryService.parseAndExecute(chatId.intValue(), agentId, "访问次数最高的部门");
}
private Integer addAgent(long dataSetId) {

View File

@@ -51,12 +51,12 @@ com.tencent.supersonic.headless.server.processor.ResultProcessor=\
### chat-server SPIs
com.tencent.supersonic.chat.server.parser.ChatParser=\
com.tencent.supersonic.chat.server.parser.ChatQueryParser=\
com.tencent.supersonic.chat.server.parser.NL2PluginParser, \
com.tencent.supersonic.chat.server.parser.NL2SQLParser,\
com.tencent.supersonic.chat.server.parser.PlainTextParser
com.tencent.supersonic.chat.server.executor.ChatExecutor=\
com.tencent.supersonic.chat.server.executor.ChatQueryExecutor=\
com.tencent.supersonic.chat.server.executor.PluginExecutor, \
com.tencent.supersonic.chat.server.executor.SqlExecutor,\
com.tencent.supersonic.chat.server.executor.PlainTextExecutor

View File

@@ -4,7 +4,7 @@ import com.tencent.supersonic.BaseApplication;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
@@ -28,7 +28,7 @@ public class BaseTest extends BaseApplication {
protected final String period = "DAY";
@Autowired
protected ChatService chatService;
protected ChatQueryService chatQueryService;
@Autowired
protected AgentService agentService;
@@ -37,33 +37,34 @@ public class BaseTest extends BaseApplication {
SemanticParseInfo semanticParseInfo = parseResp.getSelectedParses().get(0);
ChatExecuteReq request = ChatExecuteReq.builder()
.chatId(parseResp.getChatId())
.queryText(parseResp.getQueryText())
.user(DataUtils.getUser())
.parseId(semanticParseInfo.getId())
.queryId(parseResp.getQueryId())
.chatId(chatId)
.saveAnswer(true)
.build();
QueryResult queryResult = chatService.performExecution(request);
QueryResult queryResult = chatQueryService.performExecution(request);
queryResult.setChatContext(semanticParseInfo);
return queryResult;
}
protected QueryResult submitNewChat(String queryText, Integer agentId) throws Exception {
ParseResp parseResp = submitParse(queryText, agentId);
int chatId = 10;
ParseResp parseResp = submitParse(queryText, agentId, chatId);
SemanticParseInfo parseInfo = parseResp.getSelectedParses().get(0);
ChatExecuteReq request = ChatExecuteReq.builder()
.chatId(parseResp.getChatId())
.queryText(parseResp.getQueryText())
.user(DataUtils.getUser())
.parseId(parseInfo.getId())
.agentId(agentId)
.chatId(chatId)
.queryId(parseResp.getQueryId())
.saveAnswer(false)
.build();
QueryResult result = chatService.performExecution(request);
QueryResult result = chatQueryService.performExecution(request);
result.setChatContext(parseInfo);
return result;
}
@@ -74,7 +75,7 @@ public class BaseTest extends BaseApplication {
}
ChatParseReq chatParseReq = DataUtils.getChatParseReq(chatId, queryText);
chatParseReq.setAgentId(agentId);
return chatService.performParsing(chatParseReq);
return chatQueryService.performParsing(chatParseReq);
}
protected ParseResp submitParse(String queryText, Integer agentId) {
@@ -83,7 +84,7 @@ public class BaseTest extends BaseApplication {
protected ParseResp submitParseWithAgent(String queryText, Integer agentId) {
ChatParseReq chatParseReq = DataUtils.getChatParseReqWithAgent(10, queryText, agentId);
return chatService.performParsing(chatParseReq);
return chatQueryService.performParsing(chatParseReq);
}
protected void assertSchemaElements(Set<SchemaElement> expected, Set<SchemaElement> actual) {

View File

@@ -4,7 +4,7 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
import com.tencent.supersonic.headless.api.pojo.request.QueryMapReq;
import com.tencent.supersonic.headless.api.pojo.response.MapInfoResp;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import org.junit.Assert;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
@@ -15,7 +15,7 @@ import java.util.Collections;
public class MetaDiscoveryTest extends BaseTest {
@Autowired
protected ChatQueryService chatQueryService;
protected ChatLayerService chatLayerService;
@Test
public void testGetMapMeta() throws Exception {
@@ -24,7 +24,7 @@ public class MetaDiscoveryTest extends BaseTest {
queryMapReq.setTopN(10);
queryMapReq.setUser(User.getFakeUser());
queryMapReq.setDataSetNames(Collections.singletonList("超音数数据集"));
MapInfoResp mapMeta = chatQueryService.map(queryMapReq);
MapInfoResp mapMeta = chatLayerService.map(queryMapReq);
Assertions.assertNotNull(mapMeta);
Assertions.assertNotEquals(0, mapMeta.getDataSetMapInfo().get("超音数数据集").getMapFields());
@@ -39,7 +39,7 @@ public class MetaDiscoveryTest extends BaseTest {
queryMapReq.setUser(User.getFakeUser());
queryMapReq.setDataSetNames(Collections.singletonList("艺人库"));
queryMapReq.setQueryDataType(QueryDataType.TAG);
MapInfoResp mapMeta = chatQueryService.map(queryMapReq);
MapInfoResp mapMeta = chatLayerService.map(queryMapReq);
Assert.assertNotNull(mapMeta);
}
@@ -51,7 +51,7 @@ public class MetaDiscoveryTest extends BaseTest {
queryMapReq.setUser(User.getFakeUser());
queryMapReq.setDataSetNames(Collections.singletonList("超音数"));
queryMapReq.setQueryDataType(QueryDataType.METRIC);
MapInfoResp mapMeta = chatQueryService.map(queryMapReq);
MapInfoResp mapMeta = chatLayerService.map(queryMapReq);
Assert.assertNotNull(mapMeta);
}
}