diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 4762304fe..0720e5284 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.chat.server.parser; +import com.google.common.collect.Lists; import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResp; import com.tencent.supersonic.chat.server.pojo.ChatContext; @@ -15,10 +16,12 @@ import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl; import com.tencent.supersonic.common.util.ChatAppManager; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum; +import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; import com.tencent.supersonic.headless.api.pojo.response.MapResp; import com.tencent.supersonic.headless.api.pojo.response.ParseResp; @@ -35,6 +38,7 @@ import dev.langchain4j.provider.ModelProvider; import lombok.extern.slf4j.Slf4j; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.util.CollectionUtils; import java.util.*; import java.util.stream.Collectors; @@ -78,27 +82,24 @@ public class NL2SQLParser implements ChatQueryParser { QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext); 
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE); - // inject semantic parse saved by in the chat context - ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class); - ChatContext chatCtx = - chatContextService.getOrCreateContext(parseContext.getRequest().getChatId()); - if (chatCtx != null && Objects.isNull(queryNLReq.getContextParseInfo())) { - queryNLReq.setContextParseInfo(chatCtx.getParseInfo()); - } - - // for every requested dataSet, recursively invoke rule-based parser - // with different mapModes, unless any valid semantic parse is derived. + // for every requested dataSet, recursively invoke rule-based parser with different + // mapModes Set requestedDatasets = queryNLReq.getDataSetIds(); for (Long datasetId : requestedDatasets) { queryNLReq.setDataSetIds(Collections.singleton(datasetId)); - ChatParseResp parseResp = parseContext.getResponse(); - for (MapModeEnum mode : MapModeEnum.values()) { + ChatParseResp parseResp = new ChatParseResp(parseContext.getRequest().getQueryId()); + for (MapModeEnum mode : Lists.newArrayList(MapModeEnum.STRICT, MapModeEnum.MODERATE)) { queryNLReq.setMapModeEnum(mode); doParse(queryNLReq, parseResp); - if (!parseResp.getSelectedParses().isEmpty()) { - break; - } } + if (parseResp.getSelectedParses().isEmpty()) { + queryNLReq.setMapModeEnum(MapModeEnum.LOOSE); + doParse(queryNLReq, parseResp); + } + List sortedParses = parseResp.getSelectedParses().stream() + .sorted(new SemanticParseInfo.SemanticParseComparator()).limit(1) + .collect(Collectors.toList()); + parseContext.getResponse().getSelectedParses().addAll(sortedParses); } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoSortProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoSortProcessor.java index e8d149ae3..8628a995e 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoSortProcessor.java 
+++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoSortProcessor.java @@ -1,10 +1,7 @@ package com.tencent.supersonic.chat.server.processor.parse; import com.tencent.supersonic.chat.server.pojo.ParseContext; -import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; -import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.headless.chat.parser.llm.DataSetMatchResult; import lombok.extern.slf4j.Slf4j; import java.util.*; @@ -18,23 +15,7 @@ public class ParseInfoSortProcessor implements ParseResultProcessor { @Override public void process(ParseContext parseContext) { List selectedParses = parseContext.getResponse().getSelectedParses(); - - selectedParses.sort((o1, o2) -> { - DataSetMatchResult mr1 = getDataSetMatchResult(o1.getElementMatches()); - DataSetMatchResult mr2 = getDataSetMatchResult(o2.getElementMatches()); - - double difference = mr1.getMaxDatesetSimilarity() - mr2.getMaxDatesetSimilarity(); - if (difference == 0) { - difference = mr1.getMaxMetricSimilarity() - mr2.getMaxMetricSimilarity(); - if (difference == 0) { - difference = mr1.getTotalSimilarity() - mr2.getTotalSimilarity(); - } - if (difference == 0) { - difference = mr1.getMaxMetricUseCnt() - mr2.getMaxMetricUseCnt(); - } - } - return difference >= 0 ? 
-1 : 1; - }); + selectedParses.sort(new SemanticParseInfo.SemanticParseComparator()); // re-assign parseId for (int i = 0; i < selectedParses.size(); i++) { SemanticParseInfo parseInfo = selectedParses.get(i); @@ -42,26 +23,4 @@ public class ParseInfoSortProcessor implements ParseResultProcessor { } } - private DataSetMatchResult getDataSetMatchResult(List elementMatches) { - double maxMetricSimilarity = 0; - double maxDatasetSimilarity = 0; - double totalSimilarity = 0; - long maxMetricUseCnt = 0L; - for (SchemaElementMatch match : elementMatches) { - if (SchemaElementType.DATASET.equals(match.getElement().getType())) { - maxDatasetSimilarity = Math.max(maxDatasetSimilarity, match.getSimilarity()); - } - if (SchemaElementType.METRIC.equals(match.getElement().getType())) { - maxMetricSimilarity = Math.max(maxMetricSimilarity, match.getSimilarity()); - if (Objects.nonNull(match.getElement().getUseCnt())) { - maxMetricUseCnt = Math.max(maxMetricUseCnt, match.getElement().getUseCnt()); - } - } - totalSimilarity += match.getSimilarity(); - } - return DataSetMatchResult.builder().maxMetricSimilarity(maxMetricSimilarity) - .maxDatesetSimilarity(maxDatasetSimilarity).totalSimilarity(totalSimilarity) - .build(); - } - } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java index f0ebbb744..20fa1d489 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticParseInfo.java @@ -9,6 +9,7 @@ import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import com.tencent.supersonic.common.pojo.enums.FilterType; import com.tencent.supersonic.common.pojo.enums.QueryType; import com.tencent.supersonic.headless.api.pojo.request.QueryFilter; +import lombok.Builder; import lombok.Data; import 
java.util.Comparator;
@@ -46,8 +47,77 @@ public class SemanticParseInfo {
     private String textInfo;
     private Map properties = Maps.newHashMap();
 
-    private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
+    /**
+     * Similarity summary aggregated from a list of schema element matches. The misspelled
+     * "Dateset" is kept because callers use the generated accessor and builder names.
+     */
+    @Data
+    @Builder
+    public static class DataSetMatchResult {
+        private double maxMetricSimilarity;
+        private double maxDatesetSimilarity;
+        private double totalSimilarity;
+        private long maxMetricUseCnt;
+    }
 
+    /**
+     * Orders parses best-first: by max dataset similarity, then max metric similarity, then
+     * total similarity, then max metric use count, each descending. Ties compare as equal,
+     * so a stable sort keeps their input order.
+     */
+    public static class SemanticParseComparator implements Comparator<SemanticParseInfo> {
+        @Override
+        public int compare(SemanticParseInfo o1, SemanticParseInfo o2) {
+            DataSetMatchResult mr1 = getDataSetMatchResult(o1.getElementMatches());
+            DataSetMatchResult mr2 = getDataSetMatchResult(o2.getElementMatches());
+
+            // Operands are swapped (mr2 first) for descending order. Returning 0 on ties
+            // keeps the comparator antisymmetric, as the Comparator contract requires.
+            int result = Double.compare(mr2.getMaxDatesetSimilarity(),
+                    mr1.getMaxDatesetSimilarity());
+            if (result == 0) {
+                result = Double.compare(mr2.getMaxMetricSimilarity(),
+                        mr1.getMaxMetricSimilarity());
+            }
+            if (result == 0) {
+                result = Double.compare(mr2.getTotalSimilarity(), mr1.getTotalSimilarity());
+            }
+            if (result == 0) {
+                result = Long.compare(mr2.getMaxMetricUseCnt(), mr1.getMaxMetricUseCnt());
+            }
+            return result;
+        }
+
+        /**
+         * Folds the matches into max dataset/metric similarity, total similarity and max
+         * metric use count.
+         */
+        private DataSetMatchResult getDataSetMatchResult(List<SchemaElementMatch> elementMatches) {
+            double maxMetricSimilarity = 0;
+            double maxDatasetSimilarity = 0;
+            double totalSimilarity = 0;
+            long maxMetricUseCnt = 0L;
+            for (SchemaElementMatch match : elementMatches) {
+                if (SchemaElementType.DATASET.equals(match.getElement().getType())) {
+                    maxDatasetSimilarity = Math.max(maxDatasetSimilarity, match.getSimilarity());
+                }
+                if (SchemaElementType.METRIC.equals(match.getElement().getType())) {
+                    maxMetricSimilarity = Math.max(maxMetricSimilarity, match.getSimilarity());
+                    if (Objects.nonNull(match.getElement().getUseCnt())) {
+                        maxMetricUseCnt = Math.max(maxMetricUseCnt, match.getElement().getUseCnt());
+                    }
+                }
+                totalSimilarity += match.getSimilarity();
+            }
+            // maxMetricUseCnt must be passed to the builder; otherwise it stays 0 and the
+            // use-count tie-breaker in compare() never has any effect.
+            return DataSetMatchResult.builder().maxMetricSimilarity(maxMetricSimilarity)
+                    .maxDatesetSimilarity(maxDatasetSimilarity).totalSimilarity(totalSimilarity)
+                    .maxMetricUseCnt(maxMetricUseCnt).build();
+        }
+    }
+
+    private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
         @Override
         public int compare(SchemaElement o1, SchemaElement o2) {
             if (o1.getOrder() != o2.getOrder()) {
@@ -93,4 +163,24 @@ public class SemanticParseInfo {
         }
         return limit;
     }
+
+    /**
+     * Two parses are considered equal when their {@code textInfo} is equal; no other field
+     * is compared.
+     */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o)
+            return true;
+        if (o == null || getClass() != o.getClass())
+            return false;
+        SemanticParseInfo that = (SemanticParseInfo) o;
+        return Objects.equals(textInfo, that.textInfo);
+    }
+
+    /** Hashes {@code textInfo} only, consistent with {@link #equals(Object)}. */
+    @Override
+    public int hashCode() {
+        return Objects.hashCode(textInfo);
+    }
 }
diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/DataSetMatchResult.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/DataSetMatchResult.java
deleted file mode 100644
index 8ef63aea3..000000000
--- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/DataSetMatchResult.java
+++ /dev/null
@@ -1,13 +0,0 @@
-package com.tencent.supersonic.headless.chat.parser.llm;
-
-import 
lombok.Builder; -import lombok.Data; - -@Data -@Builder -public class DataSetMatchResult { - private double maxMetricSimilarity; - private double maxDatesetSimilarity; - private double totalSimilarity; - private Long maxMetricUseCnt; -} diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/HeuristicDataSetResolver.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/HeuristicDataSetResolver.java index 9302db6c5..c1f6591a1 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/HeuristicDataSetResolver.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/HeuristicDataSetResolver.java @@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.chat.parser.llm; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementType; import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo; +import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.chat.ChatQueryContext; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; @@ -36,8 +37,9 @@ public class HeuristicDataSetResolver implements DataSetResolver { } protected Long selectDataSetByMatchSimilarity(SchemaMapInfo schemaMap) { - Map dataSetMatchRet = getDataSetMatchResult(schemaMap); - Entry selectedDataset = + Map dataSetMatchRet = + getDataSetMatchResult(schemaMap); + Entry selectedDataset = dataSetMatchRet.entrySet().stream().sorted((o1, o2) -> { double difference = o1.getValue().getMaxDatesetSimilarity() - o2.getValue().getMaxDatesetSimilarity(); @@ -63,8 +65,9 @@ public class HeuristicDataSetResolver implements DataSetResolver { return null; } - protected Map getDataSetMatchResult(SchemaMapInfo schemaMap) { - Map dateSetMatchRet = new HashMap<>(); + protected Map getDataSetMatchResult( + SchemaMapInfo schemaMap) { + Map dateSetMatchRet = 
new HashMap<>(); for (Entry> entry : schemaMap.getDataSetElementMatches() .entrySet()) { double maxMetricSimilarity = 0; @@ -84,7 +87,8 @@ public class HeuristicDataSetResolver implements DataSetResolver { totalSimilarity += match.getSimilarity(); } dateSetMatchRet.put(entry.getKey(), - DataSetMatchResult.builder().maxMetricSimilarity(maxMetricSimilarity) + SemanticParseInfo.DataSetMatchResult.builder() + .maxMetricSimilarity(maxMetricSimilarity) .maxDatesetSimilarity(maxDatasetSimilarity) .totalSimilarity(totalSimilarity).build()); } diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java index 854614edf..624ae8a1a 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java @@ -129,8 +129,7 @@ public class S2VisitsDemo extends S2BaseDemo { public void addSampleChats(Integer agentId) { Long chatId = chatManageService.addChat(defaultUser, "样例对话1", agentId); submitText(chatId.intValue(), agentId, "超音数 访问次数"); - submitText(chatId.intValue(), agentId, "按部门统计"); - submitText(chatId.intValue(), agentId, "查询近30天"); + submitText(chatId.intValue(), agentId, "按部门统计近7天访问次数"); submitText(chatId.intValue(), agentId, "alice 停留时长"); submitText(chatId.intValue(), agentId, "访问次数最高的部门"); } diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MultiTurnsTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MultiTurnsTest.java deleted file mode 100644 index a5a46403d..000000000 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MultiTurnsTest.java +++ /dev/null @@ -1,90 +0,0 @@ -package com.tencent.supersonic.chat; - -import com.tencent.supersonic.chat.api.pojo.response.QueryResult; -import com.tencent.supersonic.common.pojo.DateConf; -import 
com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; -import com.tencent.supersonic.common.pojo.enums.QueryType; -import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.headless.chat.query.rule.metric.MetricFilterQuery; -import com.tencent.supersonic.util.DataUtils; -import org.junit.jupiter.api.Order; -import org.junit.jupiter.api.Test; - -import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE; - -public class MultiTurnsTest extends BaseTest { - - @Test - @Order(1) - public void queryTest_01() throws Exception { - QueryResult actualResult = submitMultiTurnChat("alice的访问次数", DataUtils.metricAgentId, - DataUtils.MULTI_TURNS_CHAT_ID); - - QueryResult expectedResult = new QueryResult(); - SemanticParseInfo expectedParseInfo = new SemanticParseInfo(); - expectedResult.setChatContext(expectedParseInfo); - - expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE); - expectedParseInfo.setAggType(NONE); - - expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); - - expectedParseInfo.getDimensionFilters().add( - DataUtils.getFilter("user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L)); - - expectedParseInfo.setDateInfo( - DataUtils.getDateConf(DateConf.DateMode.BETWEEN, unit, period, startDay, endDay)); - expectedParseInfo.setQueryType(QueryType.AGGREGATE); - - assertQueryResult(expectedResult, actualResult); - } - - @Test - @Order(2) - public void queryTest_02() throws Exception { - QueryResult actualResult = submitMultiTurnChat("停留时长呢", DataUtils.metricAgentId, - DataUtils.MULTI_TURNS_CHAT_ID); - - QueryResult expectedResult = new QueryResult(); - SemanticParseInfo expectedParseInfo = new SemanticParseInfo(); - expectedResult.setChatContext(expectedParseInfo); - expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE); - expectedParseInfo.setAggType(NONE); - - expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("停留时长")); - - 
expectedParseInfo.getDimensionFilters().add( - DataUtils.getFilter("user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L)); - - expectedParseInfo.setDateInfo( - DataUtils.getDateConf(DateConf.DateMode.BETWEEN, unit, period, startDay, endDay)); - expectedParseInfo.setQueryType(QueryType.AGGREGATE); - - assertQueryResult(expectedResult, actualResult); - } - - @Test - @Order(3) - public void queryTest_03() throws Exception { - QueryResult actualResult = submitMultiTurnChat("lucy的如何", DataUtils.metricAgentId, - DataUtils.MULTI_TURNS_CHAT_ID); - - QueryResult expectedResult = new QueryResult(); - SemanticParseInfo expectedParseInfo = new SemanticParseInfo(); - expectedResult.setChatContext(expectedParseInfo); - - expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE); - expectedParseInfo.setAggType(NONE); - - expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("停留时长")); - - expectedParseInfo.getDimensionFilters() - .add(DataUtils.getFilter("user_name", FilterOperatorEnum.EQUALS, "lucy", "用户", 2L)); - - expectedParseInfo.setDateInfo( - DataUtils.getDateConf(DateConf.DateMode.BETWEEN, unit, period, startDay, endDay)); - expectedParseInfo.setQueryType(QueryType.AGGREGATE); - - assertQueryResult(expectedResult, actualResult); - } -}