diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index aaaf60291..378a4e994 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -3,9 +3,7 @@ package com.tencent.supersonic.chat.server.parser; import com.google.common.collect.Lists; import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResp; -import com.tencent.supersonic.chat.server.pojo.ChatContext; import com.tencent.supersonic.chat.server.pojo.ParseContext; -import com.tencent.supersonic.chat.server.service.ChatContextService; import com.tencent.supersonic.chat.server.service.ChatManageService; import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.common.config.EmbeddingConfig; @@ -16,12 +14,10 @@ import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl; import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; -import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; @@ -38,9 +34,14 @@ import dev.langchain4j.provider.ModelProvider; import lombok.extern.slf4j.Slf4j; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.util.CollectionUtils; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; import java.util.stream.Collectors; import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER; @@ -97,6 +98,7 @@ public class NL2SQLParser implements ChatQueryParser { queryNLReq.setMapModeEnum(MapModeEnum.LOOSE); doParse(queryNLReq, parseResp); } + // for one dataset select the most suitable parses List sortedParses = parseResp.getSelectedParses().stream() .sorted(new SemanticParseInfo.SemanticParseComparator()).limit(1) .collect(Collectors.toList()); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java index dda36ae8a..1d585ff99 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java @@ -1,14 +1,12 @@ package com.tencent.supersonic.headless.chat; import com.fasterxml.jackson.annotation.JsonIgnore; -import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.headless.api.pojo.DataSetSchema; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticSchema; import com.tencent.supersonic.headless.api.pojo.enums.ChatWorkflowState; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; -import com.tencent.supersonic.headless.chat.parser.ParserConfig; import com.tencent.supersonic.headless.chat.query.SemanticQuery; import lombok.Data; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/ContextInheritParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/ContextInheritParser.java deleted file mode 100644 index 0996bccf2..000000000 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/ContextInheritParser.java +++ /dev/null @@ -1,137 +0,0 @@ -package com.tencent.supersonic.headless.chat.parser.rule; - -import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; -import com.tencent.supersonic.headless.api.pojo.SchemaElementType; -import com.tencent.supersonic.headless.chat.ChatQueryContext; -import com.tencent.supersonic.headless.chat.parser.SemanticParser; -import com.tencent.supersonic.headless.chat.query.QueryManager; -import com.tencent.supersonic.headless.chat.query.SemanticQuery; -import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery; -import com.tencent.supersonic.headless.chat.query.rule.detail.DetailDimensionQuery; -import com.tencent.supersonic.headless.chat.query.rule.metric.MetricIdQuery; -import com.tencent.supersonic.headless.chat.query.rule.metric.MetricModelQuery; -import com.tencent.supersonic.headless.chat.query.rule.metric.MetricSemanticQuery; -import lombok.extern.slf4j.Slf4j; - -import java.util.AbstractMap; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -/** - * ContextInheritParser tries to inherit certain schema elements from context so that in multi-turn - * conversations users don't need to mention some keyword repeatedly. - */ -@Slf4j -public class ContextInheritParser implements SemanticParser { - - private static final Map> MUTUAL_EXCLUSIVE_MAP = - Stream.of( - new AbstractMap.SimpleEntry<>(SchemaElementType.METRIC, - Arrays.asList(SchemaElementType.METRIC)), - new AbstractMap.SimpleEntry<>(SchemaElementType.DIMENSION, - Arrays.asList(SchemaElementType.DIMENSION, SchemaElementType.VALUE)), - new AbstractMap.SimpleEntry<>(SchemaElementType.VALUE, - Arrays.asList(SchemaElementType.VALUE, SchemaElementType.DIMENSION)), - new AbstractMap.SimpleEntry<>(SchemaElementType.ENTITY, - Arrays.asList(SchemaElementType.ENTITY)), - new AbstractMap.SimpleEntry<>(SchemaElementType.DATASET, - Arrays.asList(SchemaElementType.DATASET)), - new AbstractMap.SimpleEntry<>(SchemaElementType.ID, - Arrays.asList(SchemaElementType.ID))) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - - @Override - public void parse(ChatQueryContext chatQueryContext) { - if (!shouldInherit(chatQueryContext)) { - return; - } - Long dataSetId = getMatchedDataSet(chatQueryContext); - if (dataSetId == null) { - return; - } - - List elementMatches = - chatQueryContext.getMapInfo().getMatchedElements(dataSetId); - - List matchesToInherit = new ArrayList<>(); - for (SchemaElementMatch match : chatQueryContext.getRequest().getContextParseInfo() - .getElementMatches()) { - SchemaElementType matchType = match.getElement().getType(); - // mutual exclusive element types should not be inherited - RuleSemanticQuery ruleQuery = QueryManager.getRuleQuery( - chatQueryContext.getRequest().getContextParseInfo().getQueryMode()); - if (!containsTypes(elementMatches, matchType, ruleQuery)) { - match.setInherited(true); - matchesToInherit.add(match); - } - } - elementMatches.addAll(matchesToInherit); - - List queries = - RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext); - for (RuleSemanticQuery query : queries) { - query.fillParseInfo(chatQueryContext); - if (existSameQuery(query.getParseInfo().getDataSetId(), query.getQueryMode(), - chatQueryContext)) { - continue; - } - chatQueryContext.getCandidateQueries().add(query); - } - } - - private boolean existSameQuery(Long dataSetId, String queryMode, - ChatQueryContext chatQueryContext) { - for (SemanticQuery semanticQuery : chatQueryContext.getCandidateQueries()) { - if (semanticQuery.getQueryMode().equals(queryMode) - && semanticQuery.getParseInfo().getDataSetId().equals(dataSetId)) { - return true; - } - } - return false; - } - - private boolean containsTypes(List matches, SchemaElementType matchType, - RuleSemanticQuery ruleQuery) { - List types = MUTUAL_EXCLUSIVE_MAP.get(matchType); - - return matches.stream().anyMatch(m -> { - SchemaElementType type = m.getElement().getType(); - if (Objects.nonNull(ruleQuery) && ruleQuery instanceof MetricSemanticQuery - && !(ruleQuery instanceof MetricIdQuery)) { - return types.contains(type); - } - return type.equals(matchType); - }); - } - - protected boolean shouldInherit(ChatQueryContext chatQueryContext) { - // if candidates only have MetricModel mode, count in context - List metricModelQueries = - chatQueryContext.getCandidateQueries().stream() - .filter(query -> query instanceof MetricModelQuery - || query instanceof DetailDimensionQuery) - .collect(Collectors.toList()); - return metricModelQueries.size() == chatQueryContext.getCandidateQueries().size(); - } - - protected Long getMatchedDataSet(ChatQueryContext chatQueryContext) { - if (Objects.isNull(chatQueryContext) - || Objects.isNull(chatQueryContext.getRequest().getContextParseInfo()) - || Objects.isNull( - chatQueryContext.getRequest().getContextParseInfo().getDataSetId())) { - return null; - } - Long dataSetId = chatQueryContext.getRequest().getContextParseInfo().getDataSetId(); - Set queryDataSets = chatQueryContext.getMapInfo().getMatchedDataSetInfos(); - if (queryDataSets.contains(dataSetId)) { - return dataSetId; - } - return dataSetId; - } -} diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java index 19a7ce26f..c8be76c2b 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java @@ -19,8 +19,8 @@ import java.util.List; @Slf4j public class RuleSqlParser implements SemanticParser { - private static final List auxiliaryParsers = Arrays - .asList(new ContextInheritParser(), new TimeRangeParser(), new AggregateTypeParser()); + private static final List auxiliaryParsers = + Arrays.asList(new TimeRangeParser(), new AggregateTypeParser()); @Override public void parse(ChatQueryContext chatQueryContext) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/TimeRangeParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/TimeRangeParser.java index d7ba6466e..0dd6d59c1 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/TimeRangeParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/TimeRangeParser.java @@ -5,9 +5,7 @@ import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.parser.SemanticParser; -import com.tencent.supersonic.headless.chat.query.QueryManager; import com.tencent.supersonic.headless.chat.query.SemanticQuery; -import com.tencent.supersonic.headless.chat.query.rule.RuleSemanticQuery; import com.xkzhangsan.time.nlp.TimeNLP; import com.xkzhangsan.time.nlp.TimeNLPUtil; import lombok.extern.slf4j.Slf4j; @@ -38,6 +36,10 @@ public class TimeRangeParser implements SemanticParser { @Override public void parse(ChatQueryContext queryContext) { + if (queryContext.getCandidateQueries().isEmpty()) { + return; + } + String queryText = queryContext.getRequest().getQueryText(); DateConf dateConf = parseRecent(queryText); if (dateConf == null) { @@ -46,34 +48,18 @@ public class TimeRangeParser implements SemanticParser { if (dateConf == null) { dateConf = parseDateCN(queryText); } - if (dateConf != null) { updateQueryContext(queryContext, dateConf); } } private void updateQueryContext(ChatQueryContext queryContext, DateConf dateConf) { - if (!queryContext.getCandidateQueries().isEmpty()) { - for (SemanticQuery query : queryContext.getCandidateQueries()) { - SemanticParseInfo parseInfo = query.getParseInfo(); - if (queryContext.containsPartitionDimensions(parseInfo.getDataSetId())) { - parseInfo.setDateInfo(dateConf); - } - parseInfo.setScore(parseInfo.getScore() + dateConf.getDetectWord().length()); - } - } else { - SemanticParseInfo contextParseInfo = queryContext.getRequest().getContextParseInfo(); - if (QueryManager.containsRuleQuery(contextParseInfo.getQueryMode())) { - RuleSemanticQuery semanticQuery = - QueryManager.createRuleQuery(contextParseInfo.getQueryMode()); - if (queryContext.containsPartitionDimensions(contextParseInfo.getDataSetId())) { - contextParseInfo.setDateInfo(dateConf); - } - contextParseInfo - .setScore(contextParseInfo.getScore() + dateConf.getDetectWord().length()); - semanticQuery.setParseInfo(contextParseInfo); - queryContext.getCandidateQueries().add(semanticQuery); + for (SemanticQuery query : queryContext.getCandidateQueries()) { + SemanticParseInfo parseInfo = query.getParseInfo(); + if (queryContext.containsPartitionDimensions(parseInfo.getDataSetId())) { + parseInfo.setDateInfo(dateConf); } + parseInfo.setScore(parseInfo.getScore() + dateConf.getDetectWord().length()); } }