mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 20:51:48 +00:00
[improvement][project]Remove unnecessary SchemaMapInfo from ParseContext.
This commit is contained in:
@@ -100,7 +100,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
if (!parseContext.isDisableLLM()) {
|
if (!parseContext.isDisableLLM()) {
|
||||||
processMultiTurn(parseContext);
|
processMultiTurn(parseContext);
|
||||||
}
|
}
|
||||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext, chatCtx);
|
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext, chatCtx);
|
||||||
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
|
addDynamicExemplars(parseContext.getAgent().getId(), queryNLReq);
|
||||||
|
|
||||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||||
@@ -179,11 +179,11 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
|
|
||||||
// derive mapping result of current question and parsing result of last question.
|
// derive mapping result of current question and parsing result of last question.
|
||||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
||||||
MapResp currentMapResult = chatLayerService.map(queryNLReq);
|
MapResp currentMapResult = chatLayerService.map(queryNLReq);
|
||||||
|
|
||||||
List<QueryResp> historyQueries = getHistoryQueries(parseContext.getChatId(), 1);
|
List<QueryResp> historyQueries = getHistoryQueries(parseContext.getChatId(), 1);
|
||||||
if (historyQueries.size() == 0) {
|
if (historyQueries.isEmpty()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
QueryResp lastQuery = historyQueries.get(0);
|
QueryResp lastQuery = historyQueries.get(0);
|
||||||
@@ -209,9 +209,6 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
String rewrittenQuery = response.content().text();
|
String rewrittenQuery = response.content().text();
|
||||||
keyPipelineLog.info("QueryRewrite modelReq:\n{} \nmodelResp:\n{}", prompt.text(), response);
|
keyPipelineLog.info("QueryRewrite modelReq:\n{} \nmodelResp:\n{}", prompt.text(), response);
|
||||||
parseContext.setQueryText(rewrittenQuery);
|
parseContext.setQueryText(rewrittenQuery);
|
||||||
QueryNLReq rewrittenQueryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
|
||||||
MapResp rewrittenQueryMapResult = chatLayerService.map(rewrittenQueryNLReq);
|
|
||||||
parseContext.setMapInfo(rewrittenQueryMapResult.getMapInfo());
|
|
||||||
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", lastQuery.getQueryText(),
|
log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", lastQuery.getQueryText(),
|
||||||
currentMapResult.getQueryText(), rewrittenQuery);
|
currentMapResult.getQueryText(), rewrittenQuery);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import com.tencent.supersonic.chat.server.plugin.event.PluginDelEvent;
|
|||||||
import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
|
import com.tencent.supersonic.chat.server.plugin.event.PluginUpdateEvent;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
import com.tencent.supersonic.chat.server.service.PluginService;
|
import com.tencent.supersonic.chat.server.service.PluginService;
|
||||||
|
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.service.EmbeddingService;
|
import com.tencent.supersonic.common.service.EmbeddingService;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
@@ -20,6 +21,8 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||||
|
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||||
import dev.langchain4j.data.segment.TextSegment;
|
import dev.langchain4j.data.segment.TextSegment;
|
||||||
import dev.langchain4j.store.embedding.Retrieval;
|
import dev.langchain4j.store.embedding.Retrieval;
|
||||||
import dev.langchain4j.store.embedding.RetrieveQuery;
|
import dev.langchain4j.store.embedding.RetrieveQuery;
|
||||||
@@ -193,8 +196,10 @@ public class PluginManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ParseContext parseContext) {
|
public static Pair<Boolean, Set<Long>> resolve(ChatPlugin plugin, ParseContext parseContext) {
|
||||||
SchemaMapInfo schemaMapInfo = parseContext.getMapInfo();
|
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||||
Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, parseContext);
|
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
||||||
|
SchemaMapInfo schemaMapInfo = chatLayerService.map(queryNLReq).getMapInfo();
|
||||||
|
Set<Long> pluginMatchedDataSet = getPluginMatchedDataSet(plugin, schemaMapInfo);
|
||||||
if (CollectionUtils.isEmpty(pluginMatchedDataSet) && !plugin.isContainsAllDataSet()) {
|
if (CollectionUtils.isEmpty(pluginMatchedDataSet) && !plugin.isContainsAllDataSet()) {
|
||||||
return Pair.of(false, Sets.newHashSet());
|
return Pair.of(false, Sets.newHashSet());
|
||||||
}
|
}
|
||||||
@@ -260,8 +265,8 @@ public class PluginManager {
|
|||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, ParseContext parseContext) {
|
private static Set<Long> getPluginMatchedDataSet(ChatPlugin plugin, SchemaMapInfo mapInfo) {
|
||||||
Set<Long> matchedDataSets = parseContext.getMapInfo().getMatchedDataSetInfos();
|
Set<Long> matchedDataSets = mapInfo.getMatchedDataSetInfos();
|
||||||
if (plugin.isContainsAllDataSet()) {
|
if (plugin.isContainsAllDataSet()) {
|
||||||
return Sets.newHashSet(plugin.getDefaultMode());
|
return Sets.newHashSet(plugin.getDefaultMode());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,15 +7,20 @@ import com.tencent.supersonic.chat.server.plugin.PluginManager;
|
|||||||
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
|
import com.tencent.supersonic.chat.server.plugin.PluginParseResult;
|
||||||
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
|
import com.tencent.supersonic.chat.server.plugin.PluginRecallResult;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
|
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||||
import com.tencent.supersonic.common.pojo.Constants;
|
import com.tencent.supersonic.common.pojo.Constants;
|
||||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
|
import com.tencent.supersonic.headless.server.facade.service.ChatLayerService;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
@@ -48,9 +53,12 @@ public abstract class PluginRecognizer {
|
|||||||
if (plugin.isContainsAllDataSet()) {
|
if (plugin.isContainsAllDataSet()) {
|
||||||
dataSetIds = Sets.newHashSet(-1L);
|
dataSetIds = Sets.newHashSet(-1L);
|
||||||
}
|
}
|
||||||
|
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||||
|
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
||||||
|
SchemaMapInfo schemaMapInfo = chatLayerService.map(queryNLReq).getMapInfo();
|
||||||
for (Long dataSetId : dataSetIds) {
|
for (Long dataSetId : dataSetIds) {
|
||||||
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin,
|
SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin,
|
||||||
parseContext, pluginRecallResult.getDistance());
|
parseContext, schemaMapInfo, pluginRecallResult.getDistance());
|
||||||
semanticParseInfo.setQueryMode(plugin.getType());
|
semanticParseInfo.setQueryMode(plugin.getType());
|
||||||
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
semanticParseInfo.setScore(pluginRecallResult.getScore());
|
||||||
parseResp.getSelectedParses().add(semanticParseInfo);
|
parseResp.getSelectedParses().add(semanticParseInfo);
|
||||||
@@ -62,9 +70,8 @@ public abstract class PluginRecognizer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin,
|
protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin,
|
||||||
ParseContext parseContext, double distance) {
|
ParseContext parseContext, SchemaMapInfo mapInfo, double distance) {
|
||||||
List<SchemaElementMatch> schemaElementMatches =
|
List<SchemaElementMatch> schemaElementMatches = mapInfo.getMatchedElements(dataSetId);
|
||||||
parseContext.getMapInfo().getMatchedElements(dataSetId);
|
|
||||||
QueryFilters queryFilters = parseContext.getQueryFilters();
|
QueryFilters queryFilters = parseContext.getQueryFilters();
|
||||||
if (schemaElementMatches == null) {
|
if (schemaElementMatches == null) {
|
||||||
schemaElementMatches = Lists.newArrayList();
|
schemaElementMatches = Lists.newArrayList();
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package com.tencent.supersonic.chat.server.pojo;
|
|||||||
|
|
||||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||||
import com.tencent.supersonic.common.pojo.User;
|
import com.tencent.supersonic.common.pojo.User;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilters;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
|
||||||
@@ -14,7 +13,6 @@ public class ParseContext {
|
|||||||
private Integer chatId;
|
private Integer chatId;
|
||||||
private QueryFilters queryFilters;
|
private QueryFilters queryFilters;
|
||||||
private boolean saveAnswer = true;
|
private boolean saveAnswer = true;
|
||||||
private SchemaMapInfo mapInfo;
|
|
||||||
private boolean disableLLM = false;
|
private boolean disableLLM = false;
|
||||||
|
|
||||||
public boolean enableNL2SQL() {
|
public boolean enableNL2SQL() {
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ import com.tencent.supersonic.headless.api.pojo.request.DimensionValueReq;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
import com.tencent.supersonic.headless.api.pojo.response.SearchResult;
|
||||||
@@ -90,11 +89,11 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
@Autowired
|
@Autowired
|
||||||
private ChatModelService chatModelService;
|
private ChatModelService chatModelService;
|
||||||
|
|
||||||
private List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
|
private final List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
|
||||||
private List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors();
|
private final List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors();
|
||||||
private List<ParseResultProcessor> parseResultProcessors =
|
private final List<ParseResultProcessor> parseResultProcessors =
|
||||||
ComponentFactory.getParseProcessors();
|
ComponentFactory.getParseProcessors();
|
||||||
private List<ExecuteResultProcessor> executeResultProcessors =
|
private final List<ExecuteResultProcessor> executeResultProcessors =
|
||||||
ComponentFactory.getExecuteProcessors();
|
ComponentFactory.getExecuteProcessors();
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -104,7 +103,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
if (!agent.enableSearch()) {
|
if (!agent.enableSearch()) {
|
||||||
return Lists.newArrayList();
|
return Lists.newArrayList();
|
||||||
}
|
}
|
||||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
||||||
return chatLayerService.retrieve(queryNLReq);
|
return chatLayerService.retrieve(queryNLReq);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,13 +112,14 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
ParseResp parseResp = new ParseResp(chatParseReq.getQueryText());
|
ParseResp parseResp = new ParseResp(chatParseReq.getQueryText());
|
||||||
chatManageService.createChatQuery(chatParseReq, parseResp);
|
chatManageService.createChatQuery(chatParseReq, parseResp);
|
||||||
ParseContext parseContext = buildParseContext(chatParseReq);
|
ParseContext parseContext = buildParseContext(chatParseReq);
|
||||||
supplyMapInfo(parseContext);
|
|
||||||
for (ChatQueryParser chatQueryParser : chatQueryParsers) {
|
for (ChatQueryParser parser : chatQueryParsers) {
|
||||||
chatQueryParser.parse(parseContext, parseResp);
|
parser.parse(parseContext, parseResp);
|
||||||
}
|
}
|
||||||
for (ParseResultProcessor processor : parseResultProcessors) {
|
for (ParseResultProcessor processor : parseResultProcessors) {
|
||||||
processor.process(parseContext, parseResp);
|
processor.process(parseContext, parseResp);
|
||||||
}
|
}
|
||||||
|
|
||||||
chatParseReq.setQueryText(parseContext.getQueryText());
|
chatParseReq.setQueryText(parseContext.getQueryText());
|
||||||
chatManageService.batchAddParse(chatParseReq, parseResp);
|
chatManageService.batchAddParse(chatParseReq, parseResp);
|
||||||
chatManageService.updateParseCostTime(parseResp);
|
chatManageService.updateParseCostTime(parseResp);
|
||||||
@@ -175,12 +175,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
return parseContext;
|
return parseContext;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void supplyMapInfo(ParseContext parseContext) {
|
|
||||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
|
||||||
MapResp mapResp = chatLayerService.map(queryNLReq);
|
|
||||||
parseContext.setMapInfo(mapResp.getMapInfo());
|
|
||||||
}
|
|
||||||
|
|
||||||
private ExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
|
private ExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
|
||||||
ExecuteContext executeContext = new ExecuteContext();
|
ExecuteContext executeContext = new ExecuteContext();
|
||||||
BeanMapper.mapper(chatExecuteReq, executeContext);
|
BeanMapper.mapper(chatExecuteReq, executeContext);
|
||||||
@@ -197,7 +191,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
Integer parseId = chatQueryDataReq.getParseId();
|
Integer parseId = chatQueryDataReq.getParseId();
|
||||||
SemanticParseInfo parseInfo =
|
SemanticParseInfo parseInfo =
|
||||||
chatManageService.getParseInfo(chatQueryDataReq.getQueryId(), parseId);
|
chatManageService.getParseInfo(chatQueryDataReq.getQueryId(), parseId);
|
||||||
parseInfo = mergeParseInfo(parseInfo, chatQueryDataReq);
|
mergeParseInfo(parseInfo, chatQueryDataReq);
|
||||||
DataSetSchema dataSetSchema =
|
DataSetSchema dataSetSchema =
|
||||||
semanticLayerService.getDataSetSchema(parseInfo.getDataSetId());
|
semanticLayerService.getDataSetSchema(parseInfo.getDataSetId());
|
||||||
|
|
||||||
@@ -494,10 +488,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
private SemanticParseInfo mergeParseInfo(SemanticParseInfo parseInfo,
|
private void mergeParseInfo(SemanticParseInfo parseInfo, ChatQueryDataReq queryData) {
|
||||||
ChatQueryDataReq queryData) {
|
|
||||||
if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
|
if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
|
||||||
return parseInfo;
|
return;
|
||||||
}
|
}
|
||||||
if (!CollectionUtils.isEmpty(queryData.getDimensions())) {
|
if (!CollectionUtils.isEmpty(queryData.getDimensions())) {
|
||||||
parseInfo.setDimensions(queryData.getDimensions());
|
parseInfo.setDimensions(queryData.getDimensions());
|
||||||
@@ -515,7 +508,6 @@ public class ChatQueryServiceImpl implements ChatQueryService {
|
|||||||
parseInfo.setDateInfo(queryData.getDateInfo());
|
parseInfo.setDateInfo(queryData.getDateInfo());
|
||||||
}
|
}
|
||||||
parseInfo.setSqlInfo(new SqlInfo());
|
parseInfo.setSqlInfo(new SqlInfo());
|
||||||
return parseInfo;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void validFilter(Set<QueryFilter> filters) {
|
private void validFilter(Set<QueryFilter> filters) {
|
||||||
|
|||||||
@@ -12,11 +12,11 @@ import java.util.Objects;
|
|||||||
|
|
||||||
public class QueryReqConverter {
|
public class QueryReqConverter {
|
||||||
|
|
||||||
public static QueryNLReq buildText2SqlQueryReq(ParseContext parseContext) {
|
public static QueryNLReq buildQueryNLReq(ParseContext parseContext) {
|
||||||
return buildText2SqlQueryReq(parseContext, null);
|
return buildQueryNLReq(parseContext, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static QueryNLReq buildText2SqlQueryReq(ParseContext parseContext, ChatContext chatCtx) {
|
public static QueryNLReq buildQueryNLReq(ParseContext parseContext, ChatContext chatCtx) {
|
||||||
QueryNLReq queryNLReq = new QueryNLReq();
|
QueryNLReq queryNLReq = new QueryNLReq();
|
||||||
BeanMapper.mapper(parseContext, queryNLReq);
|
BeanMapper.mapper(parseContext, queryNLReq);
|
||||||
Agent agent = parseContext.getAgent();
|
Agent agent = parseContext.getAgent();
|
||||||
|
|||||||
@@ -78,8 +78,7 @@ public class QueryTypeParser implements SemanticParser {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private static List<String> filterByTimeFields(List<String> whereFields) {
|
private static List<String> filterByTimeFields(List<String> whereFields) {
|
||||||
return whereFields.stream()
|
return whereFields.stream().filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
||||||
.filter(field -> !TimeDimensionEnum.containsTimeDimension(field))
|
|
||||||
.collect(Collectors.toList());
|
.collect(Collectors.toList());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ public class S2ChatLayerService implements ChatLayerService {
|
|||||||
public ParseResp parse(QueryNLReq queryNLReq) {
|
public ParseResp parse(QueryNLReq queryNLReq) {
|
||||||
ParseResp parseResult = new ParseResp(queryNLReq.getQueryText());
|
ParseResp parseResult = new ParseResp(queryNLReq.getQueryText());
|
||||||
ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq);
|
ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq);
|
||||||
chatWorkflowEngine.execute(queryCtx, parseResult);
|
chatWorkflowEngine.start(queryCtx, parseResult);
|
||||||
return parseResult;
|
return parseResult;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ public class ChatWorkflowEngine {
|
|||||||
ComponentFactory.getSemanticCorrectors();
|
ComponentFactory.getSemanticCorrectors();
|
||||||
private final List<ResultProcessor> resultProcessors = ComponentFactory.getResultProcessors();
|
private final List<ResultProcessor> resultProcessors = ComponentFactory.getResultProcessors();
|
||||||
|
|
||||||
public void execute(ChatQueryContext queryCtx, ParseResp parseResult) {
|
public void start(ChatQueryContext queryCtx, ParseResp parseResult) {
|
||||||
queryCtx.setChatWorkflowState(ChatWorkflowState.MAPPING);
|
queryCtx.setChatWorkflowState(ChatWorkflowState.MAPPING);
|
||||||
while (queryCtx.getChatWorkflowState() != ChatWorkflowState.FINISHED) {
|
while (queryCtx.getChatWorkflowState() != ChatWorkflowState.FINISHED) {
|
||||||
switch (queryCtx.getChatWorkflowState()) {
|
switch (queryCtx.getChatWorkflowState()) {
|
||||||
@@ -122,8 +122,8 @@ public class ChatWorkflowEngine {
|
|||||||
resultProcessors.forEach(processor -> processor.process(parseResult, queryCtx));
|
resultProcessors.forEach(processor -> processor.process(parseResult, queryCtx));
|
||||||
}
|
}
|
||||||
|
|
||||||
private void performTranslating(ChatQueryContext chatQueryContext, ParseResp parseResult) {
|
private void performTranslating(ChatQueryContext queryCtx, ParseResp parseResult) {
|
||||||
List<SemanticParseInfo> semanticParseInfos = chatQueryContext.getCandidateQueries().stream()
|
List<SemanticParseInfo> semanticParseInfos = queryCtx.getCandidateQueries().stream()
|
||||||
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
|
.map(SemanticQuery::getParseInfo).collect(Collectors.toList());
|
||||||
List<String> errorMsg = new ArrayList<>();
|
List<String> errorMsg = new ArrayList<>();
|
||||||
if (StringUtils.isNotBlank(parseResult.getErrorMsg())) {
|
if (StringUtils.isNotBlank(parseResult.getErrorMsg())) {
|
||||||
@@ -140,7 +140,7 @@ public class ChatWorkflowEngine {
|
|||||||
SemanticLayerService queryService =
|
SemanticLayerService queryService =
|
||||||
ContextUtils.getBean(SemanticLayerService.class);
|
ContextUtils.getBean(SemanticLayerService.class);
|
||||||
SemanticTranslateResp explain =
|
SemanticTranslateResp explain =
|
||||||
queryService.translate(semanticQueryReq, chatQueryContext.getUser());
|
queryService.translate(semanticQueryReq, queryCtx.getUser());
|
||||||
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
|
parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL());
|
||||||
if (StringUtils.isNotBlank(explain.getErrMsg())) {
|
if (StringUtils.isNotBlank(explain.getErrMsg())) {
|
||||||
errorMsg.add(explain.getErrMsg());
|
errorMsg.add(explain.getErrMsg());
|
||||||
|
|||||||
Reference in New Issue
Block a user