mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
[improvement][project]Remove unnecessary SchemaMapInfo from ParseContext.
This commit is contained in:
@@ -100,7 +100,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
if (!parseContext.isDisableLLM()) {
|
||||
processMultiTurn(parseContext);
|
||||
}
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx);
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext, chatCtx);
|
||||
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
|
||||
|
||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||
@@ -179,11 +179,11 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
|
||||
// derive mapping result of current question and parsing result of last question.
|
||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
||||
MapResp currentMapResult = chatLayerService.map(queryNLReq);
|
||||
|
||||
List<QueryResp> historyQueries = getHistoryQueries(parseContext.getChatId(), 1);
|
||||
if (historyQueries.size() == 0) {
|
||||
if (historyQueries.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
QueryResp lastQuery = historyQueries.get(0);
|
||||
@@ -209,9 +209,6 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
String rewrittenQuery = response.content().text();
|
||||
keyPipelineLog.info("QueryRewrite modelReq:\n{} \nmodelResp:\n{}", prompt.text(), response);
|
||||
parseContext.setQueryText(rewrittenQuery);
|
||||
QueryNLReq rewrittenQueryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
||||
MapResp rewrittenQueryMapResult = chatLayerService.map(rewrittenQueryNLReq);
|
||||
parseContext.setMapInfo(rewrittenQueryMapResult.getMapInfo());
|
||||
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", lastQuery.getQueryText(),
|
||||
currentMapResult.getQueryText(), rewrittenQuery);
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent;
|
||||
import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.service.PluginService;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
@@ -20,6 +21,8 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||
import dev.langchain4j.data.segment.TextSegment;
|
||||
import dev.langchain4j.store.embedding.Retrieval;
|
||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||
@@ -193,8 +196,10 @@ public class PluginManager {
|
||||
}
|
||||
|
||||
public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ParseContext parseContext) {
|
||||
SchemaMapInfo schemaMapInfo = parseContext.getMapInfo();
|
||||
Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, parseContext);
|
||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
||||
SchemaMapInfo schemaMapInfo = chatLayerService.map(queryNLReq).getMapInfo();
|
||||
Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, schemaMapInfo);
|
||||
if (CollectionUtils.isEmpty(pluginMatchedDataSet) && !plugin.isContainsAllDataSet()) {
|
||||
return Pair.of(false, Sets.newHashSet());
|
||||
}
|
||||
@@ -260,8 +265,8 @@ public class PluginManager {
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, ParseContext parseContext) {
|
||||
Set<Long> matchedDataSets = parseContext.getMapInfo().getMatchedDataSetInfos();
|
||||
private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, SchemaMapInfo mapInfo) {
|
||||
Set<Long> matchedDataSets = mapInfo.getMatchedDataSetInfos();
|
||||
if (plugin.isContainsAllDataSet()) {
|
||||
return Sets.newHashSet(plugin.getDefaultMode());
|
||||
}
|
||||
|
||||
@@ -7,15 +7,20 @@ import com.tencent.supersonic.chat.server.plugin.PluginManager;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.HashMap;
|
||||
@@ -48,9 +53,12 @@ public abstract class PluginRecognizer {
|
||||
if (plugin.isContainsAllDataSet()) {
|
||||
dataSetIds = Sets.newHashSet(-1L);
|
||||
}
|
||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
||||
SchemaMapInfo schemaMapInfo = chatLayerService.map(queryNLReq).getMapInfo();
|
||||
for (Long dataSetId : dataSetIds) {
|
||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin,
|
||||
parseContext, pluginRecallResult.getDistance());
|
||||
parseContext, schemaMapInfo, pluginRecallResult.getDistance());
|
||||
semanticParseInfo.setQueryMode(plugin.getType());
|
||||
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
||||
parseResp.getSelectedParses().add(semanticParseInfo);
|
||||
@@ -62,9 +70,8 @@ public abstract class PluginRecognizer {
|
||||
}
|
||||
|
||||
protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin,
|
||||
ParseContext parseContext, double distance) {
|
||||
List<SchemaElementMatch> schemaElementMatches =
|
||||
parseContext.getMapInfo().getMatchedElements(dataSetId);
|
||||
ParseContext parseContext, SchemaMapInfo mapInfo, double distance) {
|
||||
List<SchemaElementMatch> schemaElementMatches = mapInfo.getMatchedElements(dataSetId);
|
||||
QueryFilters queryFilters = parseContext.getQueryFilters();
|
||||
if (schemaElementMatches == null) {
|
||||
schemaElementMatches = Lists.newArrayList();
|
||||
|
||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.server.pojo;
|
||||
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.common.pojo.User;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||
import lombok.Data;
|
||||
|
||||
@@ -14,7 +13,6 @@ public class ParseContext {
|
||||
private Integer chatId;
|
||||
private QueryFilters queryFilters;
|
||||
private boolean saveAnswer = true;
|
||||
private SchemaMapInfo mapInfo;
|
||||
private boolean disableLLM = false;
|
||||
|
||||
public boolean enableNL2SQL() {
|
||||
|
||||
@@ -38,7 +38,6 @@ import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
||||
@@ -90,11 +89,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
@Autowired
|
||||
private ChatModelService chatModelService;
|
||||
|
||||
private List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
|
||||
private List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors();
|
||||
private List<ParseResultProcessor> parseResultProcessors =
|
||||
private final List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
|
||||
private final List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors();
|
||||
private final List<ParseResultProcessor> parseResultProcessors =
|
||||
ComponentFactory.getParseProcessors();
|
||||
private List<ExecuteResultProcessor> executeResultProcessors =
|
||||
private final List<ExecuteResultProcessor> executeResultProcessors =
|
||||
ComponentFactory.getExecuteProcessors();
|
||||
|
||||
@Override
|
||||
@@ -104,7 +103,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
if (!agent.enableSearch()) {
|
||||
return Lists.newArrayList();
|
||||
}
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
||||
return chatLayerService.retrieve(queryNLReq);
|
||||
}
|
||||
|
||||
@@ -113,13 +112,14 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
ParseResp parseResp = new ParseResp(chatParseReq.getQueryText());
|
||||
chatManageService.createChatQuery(chatParseReq, parseResp);
|
||||
ParseContext parseContext = buildParseContext(chatParseReq);
|
||||
supplyMapInfo(parseContext);
|
||||
for (ChatQueryParser chatQueryParser : chatQueryParsers) {
|
||||
chatQueryParser.parse(parseContext, parseResp);
|
||||
|
||||
for (ChatQueryParser parser : chatQueryParsers) {
|
||||
parser.parse(parseContext, parseResp);
|
||||
}
|
||||
for (ParseResultProcessor processor : parseResultProcessors) {
|
||||
processor.process(parseContext, parseResp);
|
||||
}
|
||||
|
||||
chatParseReq.setQueryText(parseContext.getQueryText());
|
||||
chatManageService.batchAddParse(chatParseReq, parseResp);
|
||||
chatManageService.updateParseCostTime(parseResp);
|
||||
@@ -175,12 +175,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
return parseContext;
|
||||
}
|
||||
|
||||
private void supplyMapInfo(ParseContext parseContext) {
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
||||
MapResp mapResp = chatLayerService.map(queryNLReq);
|
||||
parseContext.setMapInfo(mapResp.getMapInfo());
|
||||
}
|
||||
|
||||
private ExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
|
||||
ExecuteContext executeContext = new ExecuteContext();
|
||||
BeanMapper.mapper(chatExecuteReq, executeContext);
|
||||
@@ -197,7 +191,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
Integer parseId = chatQueryDataReq.getParseId();
|
||||
SemanticParseInfo parseInfo =
|
||||
chatManageService.getParseInfo(chatQueryDataReq.getQueryId(), parseId);
|
||||
parseInfo = mergeParseInfo(parseInfo, chatQueryDataReq);
|
||||
mergeParseInfo(parseInfo, chatQueryDataReq);
|
||||
DataSetSchema dataSetSchema =
|
||||
semanticLayerService.getDataSetSchema(parseInfo.getDataSetId());
|
||||
|
||||
@@ -494,10 +488,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
});
|
||||
}
|
||||
|
||||
private SemanticParseInfo mergeParseInfo(SemanticParseInfo parseInfo,
|
||||
ChatQueryDataReq queryData) {
|
||||
private void mergeParseInfo(SemanticParseInfo parseInfo, ChatQueryDataReq queryData) {
|
||||
if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
|
||||
return parseInfo;
|
||||
return;
|
||||
}
|
||||
if (!CollectionUtils.isEmpty(queryData.getDimensions())) {
|
||||
parseInfo.setDimensions(queryData.getDimensions());
|
||||
@@ -515,7 +508,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
||||
parseInfo.setDateInfo(queryData.getDateInfo());
|
||||
}
|
||||
parseInfo.setSqlInfo(new SqlInfo());
|
||||
return parseInfo;
|
||||
}
|
||||
|
||||
private void validFilter(Set<QueryFilter> filters) {
|
||||
|
||||
@@ -12,11 +12,11 @@ import java.util.Objects;
|
||||
|
||||
public class QueryReqConverter {
|
||||
|
||||
public static QueryNLReq buildText2SqlQueryReq(ParseContext parseContext) {
|
||||
return buildText2SqlQueryReq(parseContext, null);
|
||||
public static QueryNLReq buildQueryNLReq(ParseContext parseContext) {
|
||||
return buildQueryNLReq(parseContext, null);
|
||||
}
|
||||
|
||||
public static QueryNLReq buildText2SqlQueryReq(ParseContext parseContext, ChatContext chatCtx) {
|
||||
public static QueryNLReq buildQueryNLReq(ParseContext parseContext, ChatContext chatCtx) {
|
||||
QueryNLReq queryNLReq = new QueryNLReq();
|
||||
BeanMapper.mapper(parseContext, queryNLReq);
|
||||
Agent agent = parseContext.getAgent();
|
||||
|
||||
@@ -78,8 +78,7 @@ public class QueryTypeParser implements SemanticParser {
|
||||
}
|
||||
|
||||
private static List<String> filterByTimeFields(List<String> whereFields) {
|
||||
return whereFields.stream()
|
||||
.filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
||||
return whereFields.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
|
||||
@@ -89,7 +89,7 @@ public class S2ChatLayerService implements ChatLayerService {
|
||||
public ParseResp parse(QueryNLReq queryNLReq) {
|
||||
ParseResp parseResult = new ParseResp(queryNLReq.getQueryText());
|
||||
ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq);
|
||||
chatWorkflowEngine.execute(queryCtx, parseResult);
|
||||
chatWorkflowEngine.start(queryCtx, parseResult);
|
||||
return parseResult;
|
||||
}
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ public class ChatWorkflowEngine {
|
||||
ComponentFactory.getSemanticCorrectors();
|
||||
private final List<ResultProcessor> resultProcessors = ComponentFactory.getResultProcessors();
|
||||
|
||||
public void execute(ChatQueryContext queryCtx, ParseResp parseResult) {
|
||||
public void start(ChatQueryContext queryCtx, ParseResp parseResult) {
|
||||
queryCtx.setChatWorkflowState(ChatWorkflowState.MAPPING);
|
||||
while (queryCtx.getChatWorkflowState() != ChatWorkflowState.FINISHED) {
|
||||
switch (queryCtx.getChatWorkflowState()) {
|
||||
@@ -122,8 +122,8 @@ public class ChatWorkflowEngine {
|
||||
resultProcessors.forEach(processor -> processor.process(parseResult, queryCtx));
|
||||
}
|
||||
|
||||
private void performTranslating(ChatQueryContext chatQueryContext, ParseResp parseResult) {
|
||||
List<SemanticParseInfo> semanticParseInfos = chatQueryContext.getCandidateQueries().stream()
|
||||
private void performTranslating(ChatQueryContext queryCtx, ParseResp parseResult) {
|
||||
List<SemanticParseInfo> semanticParseInfos = queryCtx.getCandidateQueries().stream()
|
||||
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
|
||||
List<String> errorMsg = new ArrayList<>();
|
||||
if (StringUtils.isNotBlank(parseResult.getErrorMsg())) {
|
||||
@@ -140,7 +140,7 @@ public class ChatWorkflowEngine {
|
||||
SemanticLayerService queryService =
|
||||
ContextUtils.getBean(SemanticLayerService.class);
|
||||
SemanticTranslateResp explain =
|
||||
queryService.translate(semanticQueryReq, chatQueryContext.getUser());
|
||||
queryService.translate(semanticQueryReq, queryCtx.getUser());
|
||||
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
|
||||
if (StringUtils.isNotBlank(explain.getErrMsg())) {
|
||||
errorMsg.add(explain.getErrMsg());
|
||||
|
||||
Reference in New Issue
Block a user