(improvement)(headless&chat)Move ChatContext from Headless module to Chat module.

This commit is contained in:
jerryjzhang
2024-07-12 16:56:25 +08:00
parent e365a36749
commit baff30550e
78 changed files with 399 additions and 459 deletions

View File

@@ -1,10 +0,0 @@
package com.tencent.supersonic.chat.server.executor;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
public interface ChatExecutor {
QueryResult execute(ChatExecuteContext chatExecuteContext);
}

View File

@@ -0,0 +1,10 @@
package com.tencent.supersonic.chat.server.executor;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
public interface ChatQueryExecutor {
QueryResult execute(ExecuteContext executeContext);
}

View File

@@ -4,7 +4,7 @@ import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.parser.ParserConfig;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
@@ -23,7 +23,7 @@ import java.util.stream.Collectors;
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
public class PlainTextExecutor implements ChatExecutor {
public class PlainTextExecutor implements ChatQueryExecutor {
private static final String INSTRUCTION = ""
+ "#Role: You are a nice person to talk to.\n"
@@ -34,34 +34,34 @@ public class PlainTextExecutor implements ChatExecutor {
+ "#Your response: ";
@Override
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
if (!"PLAIN_TEXT".equals(chatExecuteContext.getParseInfo().getQueryMode())) {
public QueryResult execute(ExecuteContext executeContext) {
if (!"PLAIN_TEXT".equals(executeContext.getParseInfo().getQueryMode())) {
return null;
}
String promptStr = String.format(INSTRUCTION, getHistoryInputs(chatExecuteContext),
chatExecuteContext.getQueryText());
String promptStr = String.format(INSTRUCTION, getHistoryInputs(executeContext),
executeContext.getQueryText());
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatAgent.getModelConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
QueryResult result = new QueryResult();
result.setQueryState(QueryState.SUCCESS);
result.setQueryMode(chatExecuteContext.getParseInfo().getQueryMode());
result.setQueryMode(executeContext.getParseInfo().getQueryMode());
result.setTextResult(response.content().text());
return result;
}
private String getHistoryInputs(ChatExecuteContext chatExecuteContext) {
private String getHistoryInputs(ExecuteContext executeContext) {
StringBuilder historyInput = new StringBuilder();
AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent chatAgent = agentService.getAgent(chatExecuteContext.getAgentId());
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
MultiTurnConfig agentMultiTurnConfig = chatAgent.getMultiTurnConfig();
@@ -70,7 +70,7 @@ public class PlainTextExecutor implements ChatExecutor {
? agentMultiTurnConfig.isEnableMultiTurn() : globalMultiTurnConfig;
if (Boolean.TRUE.equals(multiTurnConfig)) {
List<ParseResp> parseResps = getHistoryParseResult(chatExecuteContext.getChatId(), 5);
List<ParseResp> parseResps = getHistoryParseResult(executeContext.getChatId(), 5);
parseResps.stream().forEach(p -> {
historyInput.append(p.getQueryText());
historyInput.append(";");

View File

@@ -2,15 +2,15 @@ package com.tencent.supersonic.chat.server.executor;
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.plugin.build.PluginSemanticQuery;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
public class PluginExecutor implements ChatExecutor {
public class PluginExecutor implements ChatQueryExecutor {
@Override
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo();
public QueryResult execute(ExecuteContext executeContext) {
SemanticParseInfo parseInfo = executeContext.getParseInfo();
if (!PluginQueryManager.isPluginQuery(parseInfo.getQueryMode())) {
return null;
}

View File

@@ -2,7 +2,7 @@ package com.tencent.supersonic.chat.server.executor;
import com.tencent.supersonic.chat.api.pojo.enums.MemoryStatus;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.chat.server.util.ResultFormatter;
import com.tencent.supersonic.common.pojo.QueryColumn;
@@ -12,10 +12,10 @@ import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.chat.ChatContext;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlQuery;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import com.tencent.supersonic.headless.server.web.service.ChatContextService;
import com.tencent.supersonic.chat.server.service.ChatContextService;
import lombok.SneakyThrows;
import org.apache.commons.lang3.StringUtils;
@@ -25,12 +25,12 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
public class SqlExecutor implements ChatExecutor {
public class SqlExecutor implements ChatQueryExecutor {
@SneakyThrows
@Override
public QueryResult execute(ChatExecuteContext chatExecuteContext) {
QueryResult queryResult = doExecute(chatExecuteContext);
public QueryResult execute(ExecuteContext executeContext) {
QueryResult queryResult = doExecute(executeContext);
if (queryResult != null) {
String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(),
@@ -41,13 +41,13 @@ public class SqlExecutor implements ChatExecutor {
&& queryResult.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
memoryService.createMemory(ChatMemoryDO.builder()
.agentId(chatExecuteContext.getAgentId())
.agentId(executeContext.getAgent().getId())
.status(MemoryStatus.PENDING)
.question(chatExecuteContext.getQueryText())
.s2sql(chatExecuteContext.getParseInfo().getSqlInfo().getParsedS2SQL())
.dbSchema(buildSchemaStr(chatExecuteContext.getParseInfo()))
.createdBy(chatExecuteContext.getUser().getName())
.updatedBy(chatExecuteContext.getUser().getName())
.question(executeContext.getQueryText())
.s2sql(executeContext.getParseInfo().getSqlInfo().getParsedS2SQL())
.dbSchema(buildSchemaStr(executeContext.getParseInfo()))
.createdBy(executeContext.getUser().getName())
.updatedBy(executeContext.getUser().getName())
.createdAt(new Date())
.build());
}
@@ -57,12 +57,12 @@ public class SqlExecutor implements ChatExecutor {
}
@SneakyThrows
private QueryResult doExecute(ChatExecuteContext chatExecuteContext) {
private QueryResult doExecute(ExecuteContext executeContext) {
SemanticLayerService semanticLayer = ContextUtils.getBean(SemanticLayerService.class);
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
ChatContext chatCtx = chatContextService.getOrCreateContext(chatExecuteContext.getChatId());
SemanticParseInfo parseInfo = chatExecuteContext.getParseInfo();
ChatContext chatCtx = chatContextService.getOrCreateContext(executeContext.getChatId());
SemanticParseInfo parseInfo = executeContext.getParseInfo();
if (Objects.isNull(parseInfo.getSqlInfo())
|| StringUtils.isBlank(parseInfo.getSqlInfo().getCorrectedS2SQL())) {
return null;
@@ -74,8 +74,10 @@ public class SqlExecutor implements ChatExecutor {
sqlReq.setSqlInfo(parseInfo.getSqlInfo());
sqlReq.setDataSetId(parseInfo.getDataSetId());
long startTime = System.currentTimeMillis();
SemanticQueryResp queryResp = semanticLayer.queryByReq(sqlReq, chatExecuteContext.getUser());
SemanticQueryResp queryResp = semanticLayer.queryByReq(sqlReq, executeContext.getUser());
QueryResult queryResult = new QueryResult();
queryResult.setChatContext(parseInfo);
queryResult.setQueryMode(parseInfo.getQueryMode());
if (queryResp != null) {
queryResult.setQueryAuthorization(queryResp.getQueryAuthorization());
List<Map<String, Object>> resultList = queryResp == null ? new ArrayList<>()
@@ -85,7 +87,6 @@ public class SqlExecutor implements ChatExecutor {
queryResult.setQuerySql(queryResp.getSql());
queryResult.setQueryResults(resultList);
queryResult.setQueryColumns(columns);
queryResult.setQueryMode(parseInfo.getQueryMode());
queryResult.setQueryState(QueryState.SUCCESS);
chatCtx.setParseInfo(parseInfo);
@@ -94,7 +95,6 @@ public class SqlExecutor implements ChatExecutor {
queryResult.setQueryState(QueryState.INVALID);
queryResult.setQueryMode(parseInfo.getQueryMode());
}
queryResult.setChatContext(chatCtx.getParseInfo());
return queryResult;
}

View File

@@ -1,10 +0,0 @@
package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
public interface ChatParser {
void parse(ChatParseContext chatParseContext, ParseResp parseResp);
}

View File

@@ -0,0 +1,10 @@
package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
public interface ChatQueryParser {
void parse(ParseContext parseContext, ParseResp parseResp);
}

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.chat.server.util.ComponentFactory;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
@@ -9,18 +9,18 @@ import lombok.extern.slf4j.Slf4j;
import java.util.List;
@Slf4j
public class NL2PluginParser implements ChatParser {
public class NL2PluginParser implements ChatQueryParser {
private final List<PluginRecognizer> pluginRecognizers = ComponentFactory.getPluginRecognizers();
@Override
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
if (!chatParseContext.getAgent().containsPluginTool()) {
public void parse(ParseContext parseContext, ParseResp parseResp) {
if (!parseContext.getAgent().containsPluginTool()) {
return;
}
pluginRecognizers.forEach(pluginRecognizer -> {
pluginRecognizer.recognize(chatParseContext, parseResp);
pluginRecognizer.recognize(parseContext, parseResp);
log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),
JsonUtil.toString(parseResp));
});

View File

@@ -3,7 +3,8 @@ package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.chat.server.service.ChatContextService;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar;
@@ -17,7 +18,8 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
@@ -42,7 +44,7 @@ import java.util.stream.Collectors;
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
@Slf4j
public class NL2SQLParser implements ChatParser {
public class NL2SQLParser implements ChatQueryParser {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
@@ -73,27 +75,30 @@ public class NL2SQLParser implements ChatParser {
+ "#Response: ";
@Override
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) {
public void parse(ParseContext parseContext, ParseResp parseResp) {
if (!parseContext.enableNL2SQL() || checkSkip(parseResp)) {
return;
}
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
ChatContext chatCtx = chatContextService.getOrCreateContext(parseContext.getChatId());
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
chatParseContext.getAgent().getModelConfig());
parseContext.getAgent().getModelConfig());
processMultiTurn(chatLanguageModel, chatParseContext);
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
addDynamicExemplars(chatParseContext.getAgent().getId(), queryNLReq);
processMultiTurn(chatLanguageModel, parseContext);
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx);
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryNLReq);
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
ParseResp text2SqlParseResp = chatLayerService.performParsing(queryNLReq);
if (ParseResp.ParseState.COMPLETED.equals(text2SqlParseResp.getState())) {
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
} else {
parseResp.setErrorMsg(rewriteErrorMessage(chatLanguageModel,
chatParseContext.getQueryText(),
parseContext.getQueryText(),
text2SqlParseResp.getErrorMsg(),
queryNLReq.getDynamicExemplars(),
chatParseContext.getAgent().getExamples()));
parseContext.getAgent().getExamples()));
}
parseResp.setState(text2SqlParseResp.getState());
parseResp.getParseTimeCost().setSqlTime(text2SqlParseResp.getParseTimeCost().getSqlTime());
@@ -155,9 +160,9 @@ public class NL2SQLParser implements ChatParser {
parseInfo.setTextInfo(textBuilder.toString());
}
private void processMultiTurn(ChatLanguageModel chatLanguageModel, ChatParseContext chatParseContext) {
private void processMultiTurn(ChatLanguageModel chatLanguageModel, ParseContext parseContext) {
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig();
MultiTurnConfig agentMultiTurnConfig = parseContext.getAgent().getMultiTurnConfig();
Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
Boolean multiTurnConfig = agentMultiTurnConfig != null
@@ -167,11 +172,11 @@ public class NL2SQLParser implements ChatParser {
}
// derive mapping result of current question and parsing result of last question.
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
MapResp currentMapResult = chatQueryService.performMapping(queryNLReq);
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
MapResp currentMapResult = chatLayerService.performMapping(queryNLReq);
List<ParseResp> historyParseResults = getHistoryParseResult(chatParseContext.getChatId(), 1);
List<ParseResp> historyParseResults = getHistoryParseResult(parseContext.getChatId(), 1);
if (historyParseResults.size() == 0) {
return;
}
@@ -196,7 +201,7 @@ public class NL2SQLParser implements ChatParser {
String rewrittenQuery = response.content().text();
keyPipelineLog.info("NL2SQLParser modelResp:{}", rewrittenQuery);
chatParseContext.setQueryText(rewrittenQuery);
parseContext.setQueryText(rewrittenQuery);
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}",
lastParseResult.getQueryText(), currentMapResult.getQueryText(), rewrittenQuery);
}

View File

@@ -5,7 +5,7 @@ import com.tencent.supersonic.common.pojo.Parameter;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
@Service("ChatParserConfig")
@Service("ChatQueryParserConfig")
@Slf4j
public class ParserConfig extends ParameterConfig {

View File

@@ -1,15 +1,15 @@
package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
public class PlainTextParser implements ChatParser {
public class PlainTextParser implements ChatQueryParser {
@Override
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
if (chatParseContext.getAgent().containsAnyTool()) {
public void parse(ParseContext parseContext, ParseResp parseResp) {
if (parseContext.getAgent().containsAnyTool()) {
return;
}

View File

@@ -0,0 +1,16 @@
package com.tencent.supersonic.chat.server.persistence.dataobject;
import lombok.Data;
import java.io.Serializable;
import java.time.Instant;
@Data
public class ChatContextDO implements Serializable {
private Integer chatId;
private Instant modifiedAt;
private String user;
private String queryText;
private String semanticParse;
}

View File

@@ -0,0 +1,54 @@
package com.tencent.supersonic.chat.server.persistence.dataobject;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import java.util.Date;
@Data
@Builder
@NoArgsConstructor
@Getter
@AllArgsConstructor
public class StatisticsDO {
/**
* questionId
*/
private Long questionId;
/**
* chatId
*/
private Long chatId;
/**
* createTime
*/
private Date createTime;
/**
* queryText
*/
private String queryText;
/**
* userName
*/
private String userName;
/**
* interface
*/
private String interfaceName;
/**
* cost
*/
private Integer cost;
private Integer type;
}

View File

@@ -0,0 +1,14 @@
package com.tencent.supersonic.chat.server.persistence.mapper;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO;
import org.apache.ibatis.annotations.Mapper;
@Mapper
public interface ChatContextMapper {
ChatContextDO getContextByChatId(Integer chatId);
int updateContext(ChatContextDO contextDO);
int addContext(ChatContextDO contextDO);
}

View File

@@ -0,0 +1,12 @@
package com.tencent.supersonic.chat.server.persistence.mapper;
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import java.util.List;
@Mapper
public interface StatisticsMapper {
boolean batchSaveStatistics(@Param("list") List<StatisticsDO> list);
}

View File

@@ -0,0 +1,11 @@
package com.tencent.supersonic.chat.server.persistence.repository;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
public interface ChatContextRepository {
ChatContext getOrCreateContext(Integer chatId);
void updateContext(ChatContext chatCtx);
}

View File

@@ -0,0 +1,71 @@
package com.tencent.supersonic.chat.server.persistence.repository.impl;
import com.google.gson.Gson;
import com.tencent.supersonic.chat.server.persistence.repository.ChatContextRepository;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO;
import com.tencent.supersonic.chat.server.persistence.mapper.ChatContextMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.context.annotation.Primary;
import org.springframework.stereotype.Repository;
@Repository
@Primary
@Slf4j
public class ChatContextRepositoryImpl implements ChatContextRepository {
private final ChatContextMapper chatContextMapper;
public ChatContextRepositoryImpl(ChatContextMapper chatContextMapper) {
this.chatContextMapper = chatContextMapper;
}
@Override
public ChatContext getOrCreateContext(Integer chatId) {
ChatContextDO context = chatContextMapper.getContextByChatId(chatId);
if (context == null) {
ChatContext chatContext = new ChatContext();
chatContext.setChatId(chatId);
return chatContext;
}
return cast(context);
}
@Override
public void updateContext(ChatContext chatCtx) {
ChatContextDO context = cast(chatCtx);
if (chatContextMapper.getContextByChatId(chatCtx.getChatId()) == null) {
chatContextMapper.addContext(context);
} else {
chatContextMapper.updateContext(context);
}
}
private ChatContext cast(ChatContextDO contextDO) {
ChatContext chatContext = new ChatContext();
chatContext.setChatId(contextDO.getChatId());
chatContext.setUser(contextDO.getUser());
chatContext.setQueryText(contextDO.getQueryText());
if (contextDO.getSemanticParse() != null && !contextDO.getSemanticParse().isEmpty()) {
SemanticParseInfo semanticParseInfo = JsonUtil.toObject(contextDO.getSemanticParse(),
SemanticParseInfo.class);
chatContext.setParseInfo(semanticParseInfo);
}
return chatContext;
}
private ChatContextDO cast(ChatContext chatContext) {
ChatContextDO chatContextDO = new ChatContextDO();
chatContextDO.setChatId(chatContext.getChatId());
chatContextDO.setQueryText(chatContext.getQueryText());
chatContextDO.setUser(chatContext.getUser());
if (chatContext.getParseInfo() != null) {
Gson g = new Gson();
chatContextDO.setSemanticParse(g.toJson(chatContext.getParseInfo()));
}
return chatContextDO;
}
}

View File

@@ -183,7 +183,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
public List<ParseResp> getContextualParseInfo(Integer chatId) {
List<ChatParseDO> chatParseDOList = chatParseMapper.getContextualParseInfo(chatId);
List<ParseResp> semanticParseInfoList = chatParseDOList.stream().map(parseInfo -> {
ParseResp parseResp = new ParseResp(chatId, parseInfo.getQueryText());
ParseResp parseResp = new ParseResp(parseInfo.getQueryText());
List<SemanticParseInfo> selectedParses = new ArrayList<>();
selectedParses.add(JSONObject.parseObject(parseInfo.getParseInfo(), SemanticParseInfo.class));
parseResp.setSelectedParses(selectedParses);

View File

@@ -11,7 +11,7 @@ import com.tencent.supersonic.chat.server.plugin.build.WebBase;
import com.tencent.supersonic.chat.server.plugin.event.PluginAddEvent;
import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent;
import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.chat.server.service.PluginService;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.service.EmbeddingService;
@@ -52,9 +52,9 @@ public class PluginManager {
@Autowired
private EmbeddingService embeddingService;
public static List<ChatPlugin> getPluginAgentCanSupport(ChatParseContext chatParseContext) {
public static List<ChatPlugin> getPluginAgentCanSupport(ParseContext parseContext) {
PluginService pluginService = ContextUtils.getBean(PluginService.class);
Agent agent = chatParseContext.getAgent();
Agent agent = parseContext.getAgent();
List<ChatPlugin> plugins = pluginService.getPluginList();
if (Objects.isNull(agent)) {
return plugins;
@@ -191,9 +191,9 @@ public class PluginManager {
return String.valueOf(Integer.parseInt(id) / 1000);
}
public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ChatParseContext chatParseContext) {
SchemaMapInfo schemaMapInfo = chatParseContext.getMapInfo();
Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, chatParseContext);
public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ParseContext parseContext) {
SchemaMapInfo schemaMapInfo = parseContext.getMapInfo();
Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, parseContext);
if (CollectionUtils.isEmpty(pluginMatchedDataSet) && !plugin.isContainsAllDataSet()) {
return Pair.of(false, Sets.newHashSet());
}
@@ -259,8 +259,8 @@ public class PluginManager {
.collect(Collectors.toList());
}
private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, ChatParseContext chatParseContext) {
Set<Long> matchedDataSets = chatParseContext.getMapInfo().getMatchedDataSetInfos();
private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, ParseContext parseContext) {
Set<Long> matchedDataSets = parseContext.getMapInfo().getMatchedDataSetInfos();
if (plugin.isContainsAllDataSet()) {
return Sets.newHashSet(plugin.getDefaultMode());
}

View File

@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
import com.tencent.supersonic.chat.server.plugin.PluginManager;
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
@@ -28,22 +28,22 @@ import java.util.Set;
*/
public abstract class PluginRecognizer {
public void recognize(ChatParseContext chatParseContext, ParseResp parseResp) {
if (!checkPreCondition(chatParseContext)) {
public void recognize(ParseContext parseContext, ParseResp parseResp) {
if (!checkPreCondition(parseContext)) {
return;
}
PluginRecallResult pluginRecallResult = recallPlugin(chatParseContext);
PluginRecallResult pluginRecallResult = recallPlugin(parseContext);
if (pluginRecallResult == null) {
return;
}
buildQuery(chatParseContext, parseResp, pluginRecallResult);
buildQuery(parseContext, parseResp, pluginRecallResult);
}
public abstract boolean checkPreCondition(ChatParseContext chatParseContext);
public abstract boolean checkPreCondition(ParseContext parseContext);
public abstract PluginRecallResult recallPlugin(ChatParseContext chatParseContext);
public abstract PluginRecallResult recallPlugin(ParseContext parseContext);
public void buildQuery(ChatParseContext chatParseContext, ParseResp parseResp,
public void buildQuery(ParseContext parseContext, ParseResp parseResp,
PluginRecallResult pluginRecallResult) {
ChatPlugin plugin = pluginRecallResult.getPlugin();
Set<Long> dataSetIds = pluginRecallResult.getDataSetIds();
@@ -52,21 +52,21 @@ public abstract class PluginRecognizer {
}
for (Long dataSetId : dataSetIds) {
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin,
chatParseContext, pluginRecallResult.getDistance());
parseContext, pluginRecallResult.getDistance());
semanticParseInfo.setQueryMode(plugin.getType());
semanticParseInfo.setScore(pluginRecallResult.getScore());
parseResp.getSelectedParses().add(semanticParseInfo);
}
}
protected List<ChatPlugin> getPluginList(ChatParseContext chatParseContext) {
return PluginManager.getPluginAgentCanSupport(chatParseContext);
protected List<ChatPlugin> getPluginList(ParseContext parseContext) {
return PluginManager.getPluginAgentCanSupport(parseContext);
}
protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin,
ChatParseContext chatParseContext, double distance) {
List<SchemaElementMatch> schemaElementMatches = chatParseContext.getMapInfo().getMatchedElements(dataSetId);
QueryFilters queryFilters = chatParseContext.getQueryFilters();
ParseContext parseContext, double distance) {
List<SchemaElementMatch> schemaElementMatches = parseContext.getMapInfo().getMatchedElements(dataSetId);
QueryFilters queryFilters = parseContext.getQueryFilters();
if (schemaElementMatches == null) {
schemaElementMatches = Lists.newArrayList();
}
@@ -80,7 +80,7 @@ public abstract class PluginRecognizer {
pluginParseResult.setPlugin(plugin);
pluginParseResult.setQueryFilters(queryFilters);
pluginParseResult.setDistance(distance);
pluginParseResult.setQueryText(chatParseContext.getQueryText());
pluginParseResult.setQueryText(parseContext.getQueryText());
properties.put(Constants.CONTEXT, pluginParseResult);
properties.put("type", "plugin");
properties.put("name", plugin.getName());

View File

@@ -6,7 +6,7 @@ import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
import com.tencent.supersonic.chat.server.plugin.PluginManager;
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.common.util.ContextUtils;
import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQueryResult;
@@ -26,25 +26,25 @@ import java.util.stream.Collectors;
@Slf4j
public class EmbeddingRecallRecognizer extends PluginRecognizer {
public boolean checkPreCondition(ChatParseContext chatParseContext) {
List<ChatPlugin> plugins = getPluginList(chatParseContext);
public boolean checkPreCondition(ParseContext parseContext) {
List<ChatPlugin> plugins = getPluginList(parseContext);
return !CollectionUtils.isEmpty(plugins);
}
public PluginRecallResult recallPlugin(ChatParseContext chatParseContext) {
String text = chatParseContext.getQueryText();
public PluginRecallResult recallPlugin(ParseContext parseContext) {
String text = parseContext.getQueryText();
List<Retrieval> embeddingRetrievals = embeddingRecall(text);
if (CollectionUtils.isEmpty(embeddingRetrievals)) {
return null;
}
List<ChatPlugin> plugins = getPluginList(chatParseContext);
List<ChatPlugin> plugins = getPluginList(parseContext);
Map<Long, ChatPlugin> pluginMap = plugins.stream().collect(Collectors.toMap(ChatPlugin::getId, p -> p));
for (Retrieval embeddingRetrieval : embeddingRetrievals) {
ChatPlugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId()));
if (plugin == null) {
continue;
}
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, chatParseContext);
Pair<Boolean, Set<Long>> pair = PluginManager.resolve(plugin, parseContext);
log.info("embedding plugin resolve: {}", pair);
if (pair.getLeft()) {
Set<Long> dataSetList = pair.getRight();
@@ -53,7 +53,7 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer {
}
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
double distance = embeddingRetrieval.getDistance();
double score = chatParseContext.getQueryText().length() * (1 - distance);
double score = parseContext.getQueryText().length() * (1 - distance);
return PluginRecallResult.builder()
.plugin(plugin).dataSetIds(dataSetList).score(score).distance(distance).build();
}

View File

@@ -0,0 +1,13 @@
package com.tencent.supersonic.chat.server.pojo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import lombok.Data;
@Data
public class ChatContext {
private Integer chatId;
private String queryText;
private SemanticParseInfo parseInfo = new SemanticParseInfo();
private String user;
}

View File

@@ -1,17 +1,17 @@
package com.tencent.supersonic.chat.server.pojo;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import lombok.Data;
@Data
public class ChatExecuteContext {
public class ExecuteContext {
private User user;
private Integer agentId;
private Long queryId;
private Integer chatId;
private int parseId;
private String queryText;
private Agent agent;
private Integer chatId;
private Long queryId;
private boolean saveAnswer;
private SemanticParseInfo parseInfo;
}

View File

@@ -7,14 +7,14 @@ import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import lombok.Data;
@Data
public class ChatParseContext {
private String queryText;
private Integer chatId;
private Agent agent;
public class ParseContext {
private User user;
private String queryText;
private Agent agent;
private Integer chatId;
private QueryFilters queryFilters;
private boolean saveAnswer = true;
private SchemaMapInfo mapInfo = new SchemaMapInfo();
private SchemaMapInfo mapInfo;
public boolean enableNL2SQL() {
if (agent == null) {

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.server.processor.execute;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
@@ -28,8 +28,8 @@ public class DimensionRecommendProcessor implements ExecuteResultProcessor {
private static final int recommend_dimension_size = 5;
@Override
public void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult) {
SemanticParseInfo semanticParseInfo = chatExecuteContext.getParseInfo();
public void process(ExecuteContext executeContext, QueryResult queryResult) {
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
if (!QueryType.METRIC.equals(semanticParseInfo.getQueryType())
|| CollectionUtils.isEmpty(semanticParseInfo.getMetrics())) {
return;

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.server.processor.execute;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.chat.server.processor.ResultProcessor;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
@@ -9,6 +9,6 @@ import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
*/
public interface ExecuteResultProcessor extends ResultProcessor {
void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult);
void process(ExecuteContext executeContext, QueryResult queryResult);
}

View File

@@ -11,7 +11,7 @@ import static com.tencent.supersonic.common.pojo.Constants.TIME_FORMAT;
import static com.tencent.supersonic.common.pojo.Constants.WEEK;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.DateConf.DateMode;
import com.tencent.supersonic.common.pojo.QueryColumn;
@@ -60,15 +60,15 @@ import org.springframework.util.CollectionUtils;
public class MetricRatioProcessor implements ExecuteResultProcessor {
@Override
public void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult) {
SemanticParseInfo semanticParseInfo = chatExecuteContext.getParseInfo();
public void process(ExecuteContext executeContext, QueryResult queryResult) {
SemanticParseInfo semanticParseInfo = executeContext.getParseInfo();
AggregatorConfig aggregatorConfig = ContextUtils.getBean(AggregatorConfig.class);
if (CollectionUtils.isEmpty(semanticParseInfo.getMetrics())
|| !aggregatorConfig.getEnableRatio()
|| !QueryType.METRIC.equals(semanticParseInfo.getQueryType())) {
return;
}
AggregateInfo aggregateInfo = getAggregateInfo(chatExecuteContext.getUser(),
AggregateInfo aggregateInfo = getAggregateInfo(executeContext.getUser(),
semanticParseInfo, queryResult);
queryResult.setAggregateInfo(aggregateInfo);
}

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.server.processor.execute;
import com.alibaba.fastjson.JSONObject;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.DictWordType;
import com.tencent.supersonic.common.pojo.enums.QueryType;
@@ -34,8 +34,8 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
private static final int METRIC_RECOMMEND_SIZE = 5;
@Override
public void process(ChatExecuteContext chatExecuteContext, QueryResult queryResult) {
fillSimilarMetric(chatExecuteContext.getParseInfo());
public void process(ExecuteContext executeContext, QueryResult queryResult) {
fillSimilarMetric(executeContext.getParseInfo());
}
private void fillSimilarMetric(SemanticParseInfo parseInfo) {

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.EntityInfo;
@@ -19,7 +19,7 @@ import java.util.List;
public class EntityInfoProcessor implements ParseResultProcessor {
@Override
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
public void process(ParseContext parseContext, ParseResp parseResp) {
List<SemanticParseInfo> selectedParses = parseResp.getSelectedParses();
if (CollectionUtils.isEmpty(selectedParses)) {
return;
@@ -33,7 +33,7 @@ public class EntityInfoProcessor implements ParseResultProcessor {
//1. set entity info
SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class);
DataSetSchema dataSetSchema = semanticService.getDataSetSchema(parseInfo.getDataSetId());
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, chatParseContext.getUser());
EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, parseContext.getUser());
if (QueryManager.isTagQuery(queryMode)
|| QueryManager.isMetricQuery(queryMode)) {
parseInfo.setEntityInfo(entityInfo);

View File

@@ -1,10 +1,10 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
public interface ParseResultProcessor {
void process(ChatParseContext chatParseContext, ParseResp parseResp);
void process(ParseContext parseContext, ParseResp parseResp);
}

View File

@@ -5,7 +5,7 @@ import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.tencent.supersonic.chat.api.pojo.response.SimilarQueryRecallResp;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.service.ExemplarService;
@@ -25,15 +25,15 @@ import java.util.stream.Collectors;
public class QueryRecommendProcessor implements ParseResultProcessor {
@Override
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
CompletableFuture.runAsync(() -> doProcess(parseResp, chatParseContext));
public void process(ParseContext parseContext, ParseResp parseResp) {
CompletableFuture.runAsync(() -> doProcess(parseResp, parseContext));
}
@SneakyThrows
private void doProcess(ParseResp parseResp, ChatParseContext chatParseContext) {
private void doProcess(ParseResp parseResp, ParseContext parseContext) {
Long queryId = parseResp.getQueryId();
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(chatParseContext.getQueryText(),
chatParseContext.getAgent().getId());
List<SimilarQueryRecallResp> solvedQueries = getSimilarQueries(parseContext.getQueryText(),
parseContext.getAgent().getId());
ChatQueryDO chatQueryDO = getChatQuery(queryId);
chatQueryDO.setSimilarQueries(JSONObject.toJSONString(solvedQueries));
updateChatQuery(chatQueryDO);

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import lombok.extern.slf4j.Slf4j;
@@ -12,7 +12,7 @@ import lombok.extern.slf4j.Slf4j;
public class TimeCostProcessor implements ParseResultProcessor {
@Override
public void process(ChatParseContext chatParseContext, ParseResp parseResp) {
public void process(ParseContext parseContext, ParseResp parseResp) {
long parseStartTime = parseResp.getParseTimeCost().getParseStartTime();
parseResp.getParseTimeCost().setParseTime(
System.currentTimeMillis() - parseStartTime - parseResp.getParseTimeCost().getSqlTime());

View File

@@ -6,11 +6,10 @@ import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.BeanUtils;
@@ -32,20 +31,20 @@ import javax.validation.Valid;
public class ChatQueryController {
@Autowired
private ChatService chatService;
private ChatQueryService chatQueryService;
@PostMapping("search")
public Object search(@RequestBody ChatParseReq chatParseReq, HttpServletRequest request,
HttpServletResponse response) {
chatParseReq.setUser(UserHolder.findUser(request, response));
return chatService.search(chatParseReq);
return chatQueryService.search(chatParseReq);
}
@PostMapping("parse")
public Object parse(@RequestBody ChatParseReq chatParseReq,
HttpServletRequest request, HttpServletResponse response) throws Exception {
chatParseReq.setUser(UserHolder.findUser(request, response));
return chatService.performParsing(chatParseReq);
return chatQueryService.performParsing(chatParseReq);
}
@PostMapping("execute")
@@ -53,7 +52,7 @@ public class ChatQueryController {
HttpServletRequest request, HttpServletResponse response)
throws Exception {
chatExecuteReq.setUser(UserHolder.findUser(request, response));
return chatService.performExecution(chatExecuteReq);
return chatQueryService.performExecution(chatExecuteReq);
}
@PostMapping("/")
@@ -62,7 +61,7 @@ public class ChatQueryController {
throws Exception {
User user = UserHolder.findUser(request, response);
chatParseReq.setUser(user);
ParseResp parseResp = chatService.performParsing(chatParseReq);
ParseResp parseResp = chatQueryService.performParsing(chatParseReq);
if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) {
throw new InvalidArgumentException("parser error,no selectedParses");
@@ -72,27 +71,20 @@ public class ChatQueryController {
BeanUtils.copyProperties(chatParseReq, chatExecuteReq);
chatExecuteReq.setQueryId(parseResp.getQueryId());
chatExecuteReq.setParseId(semanticParseInfo.getId());
return chatService.performExecution(chatExecuteReq);
}
@PostMapping("queryContext")
public Object queryContext(@RequestBody QueryNLReq queryCtx,
HttpServletRequest request, HttpServletResponse response) {
queryCtx.setUser(UserHolder.findUser(request, response));
return chatService.queryContext(queryCtx.getChatId());
return chatQueryService.performExecution(chatExecuteReq);
}
@PostMapping("queryData")
public Object queryData(@RequestBody ChatQueryDataReq chatQueryDataReq,
HttpServletRequest request, HttpServletResponse response) throws Exception {
chatQueryDataReq.setUser(UserHolder.findUser(request, response));
return chatService.queryData(chatQueryDataReq, UserHolder.findUser(request, response));
return chatQueryService.queryData(chatQueryDataReq, UserHolder.findUser(request, response));
}
@PostMapping("queryDimensionValue")
public Object queryDimensionValue(@RequestBody @Valid DimensionValueReq dimensionValueReq,
HttpServletRequest request, HttpServletResponse response) throws Exception {
return chatService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response));
return chatQueryService.queryDimensionValue(dimensionValueReq, UserHolder.findUser(request, response));
}
}

View File

@@ -0,0 +1,11 @@
package com.tencent.supersonic.chat.server.service;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
public interface ChatContextService {
ChatContext getOrCreateContext(Integer chatId);
void updateContext(ChatContext chatCtx);
}

View File

@@ -4,7 +4,6 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
@@ -12,7 +11,7 @@ import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
import java.util.List;
public interface ChatService {
public interface ChatQueryService {
List<SearchResult> search(ChatParseReq chatParseReq);
@@ -24,8 +23,6 @@ public interface ChatService {
Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exception;
SemanticParseInfo queryContext(Integer chatId);
Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception;
}

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.chat.server.service;
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
import java.util.List;

View File

@@ -9,7 +9,7 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
import com.tencent.supersonic.common.config.ChatModelConfig;
@@ -36,7 +36,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
private MemoryService memoryService;
@Autowired
private ChatService chatService;
private ChatQueryService chatQueryService;
private ExecutorService executorService = Executors.newFixedThreadPool(1);
@@ -103,7 +103,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
continue;
}
try {
chatService.parseAndExecute(-1, agent.getId(), example);
chatQueryService.parseAndExecute(-1, agent.getId(), example);
} catch (Exception e) {
log.warn("agent:{} example execute failed:{}", agent.getName(), example);
}

View File

@@ -0,0 +1,30 @@
package com.tencent.supersonic.chat.server.service.impl;
import com.tencent.supersonic.chat.server.service.ChatContextService;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
import com.tencent.supersonic.chat.server.persistence.repository.ChatContextRepository;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
@Slf4j
@Service
public class ChatContextServiceImpl implements ChatContextService {
private ChatContextRepository chatContextRepository;
public ChatContextServiceImpl(ChatContextRepository chatContextRepository) {
this.chatContextRepository = chatContextRepository;
}
@Override
public ChatContext getOrCreateContext(Integer chatId) {
return chatContextRepository.getOrCreateContext(chatId);
}
@Override
public void updateContext(ChatContext chatCtx) {
log.debug("save ChatContext {}", chatCtx);
chatContextRepository.updateContext(chatCtx);
}
}

View File

@@ -6,15 +6,16 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatQueryDataReq;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.executor.ChatExecutor;
import com.tencent.supersonic.chat.server.parser.ChatParser;
import com.tencent.supersonic.chat.server.pojo.ChatExecuteContext;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor;
import com.tencent.supersonic.chat.server.parser.ChatQueryParser;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatContextService;
import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.chat.server.util.ComponentFactory;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.util.BeanMapper;
@@ -26,7 +27,7 @@ import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
import com.tencent.supersonic.headless.server.facade.service.ChatQueryService;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import com.tencent.supersonic.headless.server.facade.service.RetrieveService;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import lombok.extern.slf4j.Slf4j;
@@ -39,49 +40,51 @@ import java.util.List;
@Slf4j
@Service
public class ChatServiceImpl implements ChatService {
public class ChatQueryServiceImpl implements ChatQueryService {
@Autowired
private ChatManageService chatManageService;
@Autowired
private ChatQueryService chatQueryService;
private ChatLayerService chatLayerService;
@Autowired
private RetrieveService retrieveService;
@Autowired
private AgentService agentService;
@Autowired
private SemanticLayerService semanticLayerService;
@Autowired
private ChatContextService chatContextService;
private List<ChatParser> chatParsers = ComponentFactory.getChatParsers();
private List<ChatExecutor> chatExecutors = ComponentFactory.getChatExecutors();
private List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
private List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors();
private List<ParseResultProcessor> parseResultProcessors = ComponentFactory.getParseProcessors();
private List<ExecuteResultProcessor> executeResultProcessors = ComponentFactory.getExecuteProcessors();
@Override
public List<SearchResult> search(ChatParseReq chatParseReq) {
ChatParseContext chatParseContext = buildParseContext(chatParseReq);
Agent agent = chatParseContext.getAgent();
ParseContext parseContext = buildParseContext(chatParseReq);
Agent agent = parseContext.getAgent();
if (!agent.enableSearch()) {
return Lists.newArrayList();
}
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
return retrieveService.retrieve(queryNLReq);
}
@Override
public ParseResp performParsing(ChatParseReq chatParseReq) {
ParseResp parseResp = new ParseResp(chatParseReq.getChatId(), chatParseReq.getQueryText());
ParseResp parseResp = new ParseResp(chatParseReq.getQueryText());
chatManageService.createChatQuery(chatParseReq, parseResp);
ChatParseContext chatParseContext = buildParseContext(chatParseReq);
supplyMapInfo(chatParseContext);
for (ChatParser chatParser : chatParsers) {
chatParser.parse(chatParseContext, parseResp);
ParseContext parseContext = buildParseContext(chatParseReq);
supplyMapInfo(parseContext);
for (ChatQueryParser chatQueryParser : chatQueryParsers) {
chatQueryParser.parse(parseContext, parseResp);
}
for (ParseResultProcessor processor : parseResultProcessors) {
processor.process(chatParseContext, parseResp);
processor.process(parseContext, parseResp);
}
chatParseReq.setQueryText(chatParseContext.getQueryText());
parseResp.setQueryText(chatParseContext.getQueryText());
chatParseReq.setQueryText(parseContext.getQueryText());
parseResp.setQueryText(parseContext.getQueryText());
chatManageService.batchAddParse(chatParseReq, parseResp);
chatManageService.updateParseCostTime(parseResp);
return parseResp;
@@ -90,9 +93,9 @@ public class ChatServiceImpl implements ChatService {
@Override
public QueryResult performExecution(ChatExecuteReq chatExecuteReq) {
QueryResult queryResult = new QueryResult();
ChatExecuteContext chatExecuteContext = buildExecuteContext(chatExecuteReq);
for (ChatExecutor chatExecutor : chatExecutors) {
queryResult = chatExecutor.execute(chatExecuteContext);
ExecuteContext executeContext = buildExecuteContext(chatExecuteReq);
for (ChatQueryExecutor chatQueryExecutor : chatQueryExecutors) {
queryResult = chatQueryExecutor.execute(executeContext);
if (queryResult != null) {
break;
}
@@ -100,7 +103,7 @@ public class ChatServiceImpl implements ChatService {
if (queryResult != null) {
for (ExecuteResultProcessor processor : executeResultProcessors) {
processor.process(chatExecuteContext, queryResult);
processor.process(executeContext, queryResult);
}
saveQueryResult(chatExecuteReq, queryResult);
}
@@ -125,34 +128,36 @@ public class ChatServiceImpl implements ChatService {
executeReq.setQueryId(parseResp.getQueryId());
executeReq.setParseId(parseResp.getSelectedParses().get(0).getId());
executeReq.setQueryText(queryText);
executeReq.setChatId(parseResp.getChatId());
executeReq.setChatId(chatId);
executeReq.setUser(User.getFakeUser());
executeReq.setAgentId(agentId);
executeReq.setSaveAnswer(true);
return performExecution(executeReq);
}
private ChatParseContext buildParseContext(ChatParseReq chatParseReq) {
ChatParseContext chatParseContext = new ChatParseContext();
BeanMapper.mapper(chatParseReq, chatParseContext);
private ParseContext buildParseContext(ChatParseReq chatParseReq) {
ParseContext parseContext = new ParseContext();
BeanMapper.mapper(chatParseReq, parseContext);
Agent agent = agentService.getAgent(chatParseReq.getAgentId());
chatParseContext.setAgent(agent);
return chatParseContext;
parseContext.setAgent(agent);
return parseContext;
}
private void supplyMapInfo(ChatParseContext chatParseContext) {
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
MapResp mapResp = chatQueryService.performMapping(queryNLReq);
chatParseContext.setMapInfo(mapResp.getMapInfo());
private void supplyMapInfo(ParseContext parseContext) {
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
MapResp mapResp = chatLayerService.performMapping(queryNLReq);
parseContext.setMapInfo(mapResp.getMapInfo());
}
private ChatExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
ChatExecuteContext chatExecuteContext = new ChatExecuteContext();
BeanMapper.mapper(chatExecuteReq, chatExecuteContext);
private ExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
ExecuteContext executeContext = new ExecuteContext();
BeanMapper.mapper(chatExecuteReq, executeContext);
SemanticParseInfo parseInfo = chatManageService.getParseInfo(
chatExecuteReq.getQueryId(), chatExecuteReq.getParseId());
chatExecuteContext.setParseInfo(parseInfo);
return chatExecuteContext;
Agent agent = agentService.getAgent(chatExecuteReq.getAgentId());
executeContext.setAgent(agent);
executeContext.setParseInfo(parseInfo);
return executeContext;
}
@Override
@@ -163,12 +168,7 @@ public class ChatServiceImpl implements ChatService {
QueryDataReq queryData = new QueryDataReq();
BeanMapper.mapper(chatQueryDataReq, queryData);
queryData.setParseInfo(parseInfo);
return chatQueryService.executeDirectQuery(queryData, user);
}
@Override
public SemanticParseInfo queryContext(Integer chatId) {
return chatQueryService.queryContext(chatId);
return chatLayerService.executeDirectQuery(queryData, user);
}
@Override
@@ -176,7 +176,7 @@ public class ChatServiceImpl implements ChatService {
Integer agentId = dimensionValueReq.getAgentId();
Agent agent = agentService.getAgent(agentId);
dimensionValueReq.setDataSetIds(agent.getDataSetIds());
return chatQueryService.queryDimensionValue(dimensionValueReq, user);
return chatLayerService.queryDimensionValue(dimensionValueReq, user);
}
public void saveQueryResult(ChatExecuteReq chatExecuteReq, QueryResult queryResult) {

View File

@@ -1,8 +1,8 @@
package com.tencent.supersonic.chat.server.service.impl;
import com.tencent.supersonic.headless.server.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO;
import com.tencent.supersonic.chat.server.service.StatisticsService;
import com.tencent.supersonic.headless.server.persistence.mapper.StatisticsMapper;
import com.tencent.supersonic.chat.server.persistence.mapper.StatisticsMapper;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Async;

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.chat.server.executor.ChatExecutor;
import com.tencent.supersonic.chat.server.parser.ChatParser;
import com.tencent.supersonic.chat.server.executor.ChatQueryExecutor;
import com.tencent.supersonic.chat.server.parser.ChatQueryParser;
import com.tencent.supersonic.chat.server.plugin.recognize.PluginRecognizer;
import com.tencent.supersonic.chat.server.processor.execute.ExecuteResultProcessor;
import com.tencent.supersonic.chat.server.processor.parse.ParseResultProcessor;
@@ -16,8 +16,8 @@ import java.util.List;
public class ComponentFactory {
private static List<ParseResultProcessor> parseProcessors = new ArrayList<>();
private static List<ExecuteResultProcessor> executeProcessors = new ArrayList<>();
private static List<ChatParser> chatParsers = new ArrayList<>();
private static List<ChatExecutor> chatExecutors = new ArrayList<>();
private static List<ChatQueryParser> chatQueryParsers = new ArrayList<>();
private static List<ChatQueryExecutor> chatQueryExecutors = new ArrayList<>();
private static List<PluginRecognizer> pluginRecognizers = new ArrayList<>();
public static List<ParseResultProcessor> getParseProcessors() {
@@ -30,14 +30,14 @@ public class ComponentFactory {
? init(ExecuteResultProcessor.class, executeProcessors) : executeProcessors;
}
public static List<ChatParser> getChatParsers() {
return CollectionUtils.isEmpty(chatParsers)
? init(ChatParser.class, chatParsers) : chatParsers;
public static List<ChatQueryParser> getChatParsers() {
return CollectionUtils.isEmpty(chatQueryParsers)
? init(ChatQueryParser.class, chatQueryParsers) : chatQueryParsers;
}
public static List<ChatExecutor> getChatExecutors() {
return CollectionUtils.isEmpty(chatExecutors)
? init(ChatExecutor.class, chatExecutors) : chatExecutors;
public static List<ChatQueryExecutor> getChatExecutors() {
return CollectionUtils.isEmpty(chatQueryExecutors)
? init(ChatQueryExecutor.class, chatQueryExecutors) : chatQueryExecutors;
}
public static List<PluginRecognizer> getPluginRecognizers() {

View File

@@ -1,20 +1,25 @@
package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
import org.apache.commons.collections.MapUtils;
import java.util.Objects;
public class QueryReqConverter {
public static QueryNLReq buildText2SqlQueryReq(ChatParseContext chatParseContext) {
public static QueryNLReq buildText2SqlQueryReq(ParseContext parseContext) {
return buildText2SqlQueryReq(parseContext, null);
}
public static QueryNLReq buildText2SqlQueryReq(ParseContext parseContext, ChatContext chatCtx) {
QueryNLReq queryNLReq = new QueryNLReq();
BeanMapper.mapper(chatParseContext, queryNLReq);
Agent agent = chatParseContext.getAgent();
BeanMapper.mapper(parseContext, queryNLReq);
Agent agent = parseContext.getAgent();
if (agent == null) {
return queryNLReq;
}
@@ -39,6 +44,9 @@ public class QueryReqConverter {
}
queryNLReq.setModelConfig(agent.getModelConfig());
queryNLReq.setPromptConfig(agent.getPromptConfig());
if (chatCtx != null) {
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
}
return queryNLReq;
}

View File

@@ -0,0 +1,30 @@
<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.tencent.supersonic.chat.server.persistence.mapper.ChatContextMapper">
<resultMap id="ChatContextDO"
type="com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO">
<id column="chat_id" property="chatId"/>
<result column="modified_at" property="modifiedAt"/>
<result column="user" property="user"/>
<result column="query_text" property="queryText"/>
<result column="semantic_parse" property="semanticParse"/>
<!--<result column="ext_data" property="extData"/>-->
</resultMap>
<select id="getContextByChatId" resultMap="ChatContextDO">
select *
from s2_chat_context where chat_id=#{chatId} limit 1
</select>
<insert id="addContext" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.ChatContextDO" >
insert into s2_chat_context (chat_id,user,query_text,semantic_parse) values (#{chatId}, #{user},#{queryText}, #{semanticParse})
</insert>
<update id="updateContext">
update s2_chat_context set query_text=#{queryText},semantic_parse=#{semanticParse} where chat_id=#{chatId}
</update>
</mapper>

View File

@@ -0,0 +1,28 @@
<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.tencent.supersonic.chat.server.persistence.mapper.StatisticsMapper">
<resultMap id="Statistics" type="com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO">
<id column="question_id" property="questionId"/>
<result column="chat_id" property="chatId"/>
<result column="user_name" property="userName"/>
<result column="query_text" property="queryText"/>
<result column="interface_name" property="interfaceName"/>
<result column="cost" property="cost"/>
<result column="type" property="type"/>
<result column="create_time" property="createTime"/>
</resultMap>
<insert id="batchSaveStatistics" parameterType="com.tencent.supersonic.chat.server.persistence.dataobject.StatisticsDO">
insert into s2_chat_statistics
(question_id,chat_id, user_name, query_text, interface_name,cost,type ,create_time)
values
<foreach collection="list" item="item" index="index" separator=",">
(#{item.questionId}, #{item.chatId}, #{item.userName}, #{item.queryText}, #{item.interfaceName}, #{item.cost}, #{item.type},#{item.createTime})
</foreach>
</insert>
</mapper>