[improvement][project]Remove unnecessary SchemaMapInfo from ParseContext.

This commit is contained in:
jerryjzhang
2024-10-27 15:14:06 +08:00
parent 1e3daffade
commit 397b527bc6
9 changed files with 44 additions and 46 deletions

View File

@@ -100,7 +100,7 @@ public class NL2SQLParser implements ChatQueryParser {
if (!parseContext.isDisableLLM()) { if (!parseContext.isDisableLLM()) {
processMultiTurn(parseContext); processMultiTurn(parseContext);
} }
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx); QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext, chatCtx);
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq); addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class); ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
@@ -179,11 +179,11 @@ public class NL2SQLParser implements ChatQueryParser {
// derive mapping result of current question and parsing result of last question. // derive mapping result of current question and parsing result of last question.
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class); ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext); QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
MapResp currentMapResult = chatLayerService.map(queryNLReq); MapResp currentMapResult = chatLayerService.map(queryNLReq);
List<QueryResp> historyQueries = getHistoryQueries(parseContext.getChatId(), 1); List<QueryResp> historyQueries = getHistoryQueries(parseContext.getChatId(), 1);
if (historyQueries.size() == 0) { if (historyQueries.isEmpty()) {
return; return;
} }
QueryResp lastQuery = historyQueries.get(0); QueryResp lastQuery = historyQueries.get(0);
@@ -209,9 +209,6 @@ public class NL2SQLParser implements ChatQueryParser {
String rewrittenQuery = response.content().text(); String rewrittenQuery = response.content().text();
keyPipelineLog.info("QueryRewrite modelReq:\n{} \nmodelResp:\n{}", prompt.text(), response); keyPipelineLog.info("QueryRewrite modelReq:\n{} \nmodelResp:\n{}", prompt.text(), response);
parseContext.setQueryText(rewrittenQuery); parseContext.setQueryText(rewrittenQuery);
QueryNLReq rewrittenQueryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
MapResp rewrittenQueryMapResult = chatLayerService.map(rewrittenQueryNLReq);
parseContext.setMapInfo(rewrittenQueryMapResult.getMapInfo());
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", lastQuery.getQueryText(), log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", lastQuery.getQueryText(),
currentMapResult.getQueryText(), rewrittenQuery); currentMapResult.getQueryText(), rewrittenQuery);
} }

View File

@@ -13,6 +13,7 @@ import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent;
import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent; import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
import com.tencent.supersonic.chat.server.pojo.ParseContext; import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.chat.server.service.PluginService; import com.tencent.supersonic.chat.server.service.PluginService;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.service.EmbeddingService; import com.tencent.supersonic.common.service.EmbeddingService;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
@@ -20,6 +21,8 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.store.embedding.Retrieval; import dev.langchain4j.store.embedding.Retrieval;
import dev.langchain4j.store.embedding.RetrieveQuery; import dev.langchain4j.store.embedding.RetrieveQuery;
@@ -193,8 +196,10 @@ public class PluginManager {
} }
public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ParseContext parseContext) { public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ParseContext parseContext) {
SchemaMapInfo schemaMapInfo = parseContext.getMapInfo(); ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, parseContext); QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
SchemaMapInfo schemaMapInfo = chatLayerService.map(queryNLReq).getMapInfo();
Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, schemaMapInfo);
if (CollectionUtils.isEmpty(pluginMatchedDataSet) && !plugin.isContainsAllDataSet()) { if (CollectionUtils.isEmpty(pluginMatchedDataSet) && !plugin.isContainsAllDataSet()) {
return Pair.of(false, Sets.newHashSet()); return Pair.of(false, Sets.newHashSet());
} }
@@ -260,8 +265,8 @@ public class PluginManager {
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, ParseContext parseContext) { private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, SchemaMapInfo mapInfo) {
Set<Long> matchedDataSets = parseContext.getMapInfo().getMatchedDataSetInfos(); Set<Long> matchedDataSets = mapInfo.getMatchedDataSetInfos();
if (plugin.isContainsAllDataSet()) { if (plugin.isContainsAllDataSet()) {
return Sets.newHashSet(plugin.getDefaultMode()); return Sets.newHashSet(plugin.getDefaultMode());
} }

View File

@@ -7,15 +7,20 @@ import com.tencent.supersonic.chat.server.plugin.PluginManager;
import com.tencent.supersonic.chat.server.plugin.PluginParseResult; import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult; import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
import com.tencent.supersonic.chat.server.pojo.ParseContext; import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.HashMap; import java.util.HashMap;
@@ -48,9 +53,12 @@ public abstract class PluginRecognizer {
if (plugin.isContainsAllDataSet()) { if (plugin.isContainsAllDataSet()) {
dataSetIds = Sets.newHashSet(-1L); dataSetIds = Sets.newHashSet(-1L);
} }
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
SchemaMapInfo schemaMapInfo = chatLayerService.map(queryNLReq).getMapInfo();
for (Long dataSetId : dataSetIds) { for (Long dataSetId : dataSetIds) {
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin, SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin,
parseContext, pluginRecallResult.getDistance()); parseContext, schemaMapInfo, pluginRecallResult.getDistance());
semanticParseInfo.setQueryMode(plugin.getType()); semanticParseInfo.setQueryMode(plugin.getType());
semanticParseInfo.setScore(pluginRecallResult.getScore()); semanticParseInfo.setScore(pluginRecallResult.getScore());
parseResp.getSelectedParses().add(semanticParseInfo); parseResp.getSelectedParses().add(semanticParseInfo);
@@ -62,9 +70,8 @@ public abstract class PluginRecognizer {
} }
protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin, protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin,
ParseContext parseContext, double distance) { ParseContext parseContext, SchemaMapInfo mapInfo, double distance) {
List<SchemaElementMatch> schemaElementMatches = List<SchemaElementMatch> schemaElementMatches = mapInfo.getMatchedElements(dataSetId);
parseContext.getMapInfo().getMatchedElements(dataSetId);
QueryFilters queryFilters = parseContext.getQueryFilters(); QueryFilters queryFilters = parseContext.getQueryFilters();
if (schemaElementMatches == null) { if (schemaElementMatches == null) {
schemaElementMatches = Lists.newArrayList(); schemaElementMatches = Lists.newArrayList();

View File

@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.server.pojo;
import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.common.pojo.User; import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters; import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
import lombok.Data; import lombok.Data;
@@ -14,7 +13,6 @@ public class ParseContext {
private Integer chatId; private Integer chatId;
private QueryFilters queryFilters; private QueryFilters queryFilters;
private boolean saveAnswer = true; private boolean saveAnswer = true;
private SchemaMapInfo mapInfo;
private boolean disableLLM = false; private boolean disableLLM = false;
public boolean enableNL2SQL() { public boolean enableNL2SQL() {

View File

@@ -38,7 +38,6 @@ import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq; import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryState; import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.api.pojo.response.SearchResult; import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
@@ -90,11 +89,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
@Autowired @Autowired
private ChatModelService chatModelService; private ChatModelService chatModelService;
private List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers(); private final List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
private List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors(); private final List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors();
private List<ParseResultProcessor> parseResultProcessors = private final List<ParseResultProcessor> parseResultProcessors =
ComponentFactory.getParseProcessors(); ComponentFactory.getParseProcessors();
private List<ExecuteResultProcessor> executeResultProcessors = private final List<ExecuteResultProcessor> executeResultProcessors =
ComponentFactory.getExecuteProcessors(); ComponentFactory.getExecuteProcessors();
@Override @Override
@@ -104,7 +103,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
if (!agent.enableSearch()) { if (!agent.enableSearch()) {
return Lists.newArrayList(); return Lists.newArrayList();
} }
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext); QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
return chatLayerService.retrieve(queryNLReq); return chatLayerService.retrieve(queryNLReq);
} }
@@ -113,13 +112,14 @@ public class ChatQueryServiceImpl implements ChatQueryService {
ParseResp parseResp = new ParseResp(chatParseReq.getQueryText()); ParseResp parseResp = new ParseResp(chatParseReq.getQueryText());
chatManageService.createChatQuery(chatParseReq, parseResp); chatManageService.createChatQuery(chatParseReq, parseResp);
ParseContext parseContext = buildParseContext(chatParseReq); ParseContext parseContext = buildParseContext(chatParseReq);
supplyMapInfo(parseContext);
for (ChatQueryParser chatQueryParser : chatQueryParsers) { for (ChatQueryParser parser : chatQueryParsers) {
chatQueryParser.parse(parseContext, parseResp); parser.parse(parseContext, parseResp);
} }
for (ParseResultProcessor processor : parseResultProcessors) { for (ParseResultProcessor processor : parseResultProcessors) {
processor.process(parseContext, parseResp); processor.process(parseContext, parseResp);
} }
chatParseReq.setQueryText(parseContext.getQueryText()); chatParseReq.setQueryText(parseContext.getQueryText());
chatManageService.batchAddParse(chatParseReq, parseResp); chatManageService.batchAddParse(chatParseReq, parseResp);
chatManageService.updateParseCostTime(parseResp); chatManageService.updateParseCostTime(parseResp);
@@ -175,12 +175,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return parseContext; return parseContext;
} }
private void supplyMapInfo(ParseContext parseContext) {
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
MapResp mapResp = chatLayerService.map(queryNLReq);
parseContext.setMapInfo(mapResp.getMapInfo());
}
private ExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) { private ExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
ExecuteContext executeContext = new ExecuteContext(); ExecuteContext executeContext = new ExecuteContext();
BeanMapper.mapper(chatExecuteReq, executeContext); BeanMapper.mapper(chatExecuteReq, executeContext);
@@ -197,7 +191,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
Integer parseId = chatQueryDataReq.getParseId(); Integer parseId = chatQueryDataReq.getParseId();
SemanticParseInfo parseInfo = SemanticParseInfo parseInfo =
chatManageService.getParseInfo(chatQueryDataReq.getQueryId(), parseId); chatManageService.getParseInfo(chatQueryDataReq.getQueryId(), parseId);
parseInfo = mergeParseInfo(parseInfo, chatQueryDataReq); mergeParseInfo(parseInfo, chatQueryDataReq);
DataSetSchema dataSetSchema = DataSetSchema dataSetSchema =
semanticLayerService.getDataSetSchema(parseInfo.getDataSetId()); semanticLayerService.getDataSetSchema(parseInfo.getDataSetId());
@@ -494,10 +488,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
}); });
} }
private SemanticParseInfo mergeParseInfo(SemanticParseInfo parseInfo, private void mergeParseInfo(SemanticParseInfo parseInfo, ChatQueryDataReq queryData) {
ChatQueryDataReq queryData) {
if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) { if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
return parseInfo; return;
} }
if (!CollectionUtils.isEmpty(queryData.getDimensions())) { if (!CollectionUtils.isEmpty(queryData.getDimensions())) {
parseInfo.setDimensions(queryData.getDimensions()); parseInfo.setDimensions(queryData.getDimensions());
@@ -515,7 +508,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
parseInfo.setDateInfo(queryData.getDateInfo()); parseInfo.setDateInfo(queryData.getDateInfo());
} }
parseInfo.setSqlInfo(new SqlInfo()); parseInfo.setSqlInfo(new SqlInfo());
return parseInfo;
} }
private void validFilter(Set<QueryFilter> filters) { private void validFilter(Set<QueryFilter> filters) {

View File

@@ -12,11 +12,11 @@ import java.util.Objects;
public class QueryReqConverter { public class QueryReqConverter {
public static QueryNLReq buildText2SqlQueryReq(ParseContext parseContext) { public static QueryNLReq buildQueryNLReq(ParseContext parseContext) {
return buildText2SqlQueryReq(parseContext, null); return buildQueryNLReq(parseContext, null);
} }
public static QueryNLReq buildText2SqlQueryReq(ParseContext parseContext, ChatContext chatCtx) { public static QueryNLReq buildQueryNLReq(ParseContext parseContext, ChatContext chatCtx) {
QueryNLReq queryNLReq = new QueryNLReq(); QueryNLReq queryNLReq = new QueryNLReq();
BeanMapper.mapper(parseContext, queryNLReq); BeanMapper.mapper(parseContext, queryNLReq);
Agent agent = parseContext.getAgent(); Agent agent = parseContext.getAgent();

View File

@@ -78,8 +78,7 @@ public class QueryTypeParser implements SemanticParser {
} }
private static List<String> filterByTimeFields(List<String> whereFields) { private static List<String> filterByTimeFields(List<String> whereFields) {
return whereFields.stream() return whereFields.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
.filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
.collect(Collectors.toList()); .collect(Collectors.toList());
} }

View File

@@ -89,7 +89,7 @@ public class S2ChatLayerService implements ChatLayerService {
public ParseResp parse(QueryNLReq queryNLReq) { public ParseResp parse(QueryNLReq queryNLReq) {
ParseResp parseResult = new ParseResp(queryNLReq.getQueryText()); ParseResp parseResult = new ParseResp(queryNLReq.getQueryText());
ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq); ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq);
chatWorkflowEngine.execute(queryCtx, parseResult); chatWorkflowEngine.start(queryCtx, parseResult);
return parseResult; return parseResult;
} }

View File

@@ -36,7 +36,7 @@ public class ChatWorkflowEngine {
ComponentFactory.getSemanticCorrectors(); ComponentFactory.getSemanticCorrectors();
private final List<ResultProcessor> resultProcessors = ComponentFactory.getResultProcessors(); private final List<ResultProcessor> resultProcessors = ComponentFactory.getResultProcessors();
public void execute(ChatQueryContext queryCtx, ParseResp parseResult) { public void start(ChatQueryContext queryCtx, ParseResp parseResult) {
queryCtx.setChatWorkflowState(ChatWorkflowState.MAPPING); queryCtx.setChatWorkflowState(ChatWorkflowState.MAPPING);
while (queryCtx.getChatWorkflowState() != ChatWorkflowState.FINISHED) { while (queryCtx.getChatWorkflowState() != ChatWorkflowState.FINISHED) {
switch (queryCtx.getChatWorkflowState()) { switch (queryCtx.getChatWorkflowState()) {
@@ -122,8 +122,8 @@ public class ChatWorkflowEngine {
resultProcessors.forEach(processor -> processor.process(parseResult, queryCtx)); resultProcessors.forEach(processor -> processor.process(parseResult, queryCtx));
} }
private void performTranslating(ChatQueryContext chatQueryContext, ParseResp parseResult) { private void performTranslating(ChatQueryContext queryCtx, ParseResp parseResult) {
List<SemanticParseInfo> semanticParseInfos = chatQueryContext.getCandidateQueries().stream() List<SemanticParseInfo> semanticParseInfos = queryCtx.getCandidateQueries().stream()
.map(SemanticQuery::getParseInfo).collect(Collectors.toList()); .map(SemanticQuery::getParseInfo).collect(Collectors.toList());
List<String> errorMsg = new ArrayList<>(); List<String> errorMsg = new ArrayList<>();
if (StringUtils.isNotBlank(parseResult.getErrorMsg())) { if (StringUtils.isNotBlank(parseResult.getErrorMsg())) {
@@ -140,7 +140,7 @@ public class ChatWorkflowEngine {
SemanticLayerService queryService = SemanticLayerService queryService =
ContextUtils.getBean(SemanticLayerService.class); ContextUtils.getBean(SemanticLayerService.class);
SemanticTranslateResp explain = SemanticTranslateResp explain =
queryService.translate(semanticQueryReq, chatQueryContext.getUser()); queryService.translate(semanticQueryReq, queryCtx.getUser());
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL()); parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
if (StringUtils.isNotBlank(explain.getErrMsg())) { if (StringUtils.isNotBlank(explain.getErrMsg())) {
errorMsg.add(explain.getErrMsg()); errorMsg.add(explain.getErrMsg());