[improvement][chat]Optimize NL2SQL parsing logic.
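This change to tencentmusic/supersonic (https://github.com/tencentmusic/supersonic.git) reworks the NL2SQL parsing path:

- NL2SQLParser: each requested dataSet is now parsed independently. The rule-based parser is invoked with MapModeEnum.STRICT and then MODERATE (instead of iterating every MapModeEnum value), LOOSE is used only as a fallback when neither mode yields a parse, and results are collected into a fresh ChatParseResp per dataSet. Only the top-ranked parse per dataSet, chosen by the new SemanticParseInfo.SemanticParseComparator, is added to the chat response.
- ParseInfoSortProcessor: the inline lambda comparator and its private getDataSetMatchResult helper are replaced by a single sort with the shared SemanticParseInfo.SemanticParseComparator.
- SemanticParseInfo: gains nested DataSetMatchResult and SemanticParseComparator classes plus equals/hashCode keyed on textInfo; the standalone com.tencent.supersonic.headless.chat.parser.llm.DataSetMatchResult class is deleted and HeuristicDataSetResolver now references the nested type.
- Demo/tests: one sample utterance is dropped from S2VisitsDemo.addSampleChats and the MultiTurnsTest multi-turn integration test is removed.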
@@ -1,5 +1,6 @@
package com.tencent.supersonic.chat.server.parser;

import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
@@ -15,10 +16,12 @@ import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
@@ -35,6 +38,7 @@ import dev.langchain4j.provider.ModelProvider;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;

import java.util.*;
import java.util.stream.Collectors;
@@ -78,27 +82,24 @@ public class NL2SQLParser implements ChatQueryParser {
        QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
        queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);

        // inject the semantic parse saved in the chat context
        ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
        ChatContext chatCtx =
                chatContextService.getOrCreateContext(parseContext.getRequest().getChatId());
        if (chatCtx != null && Objects.isNull(queryNLReq.getContextParseInfo())) {
            queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
        }

        // for every requested dataSet, recursively invoke rule-based parser
        // with different mapModes, unless any valid semantic parse is derived.
        // for every requested dataSet, recursively invoke rule-based parser with different
        // mapModes
        Set<Long> requestedDatasets = queryNLReq.getDataSetIds();
        for (Long datasetId : requestedDatasets) {
            queryNLReq.setDataSetIds(Collections.singleton(datasetId));
            ChatParseResp parseResp = parseContext.getResponse();
            for (MapModeEnum mode : MapModeEnum.values()) {
            ChatParseResp parseResp = new ChatParseResp(parseContext.getRequest().getQueryId());
            for (MapModeEnum mode : Lists.newArrayList(MapModeEnum.STRICT, MapModeEnum.MODERATE)) {
                queryNLReq.setMapModeEnum(mode);
                doParse(queryNLReq, parseResp);
                if (!parseResp.getSelectedParses().isEmpty()) {
                    break;
                }
            }
            if (parseResp.getSelectedParses().isEmpty()) {
                queryNLReq.setMapModeEnum(MapModeEnum.LOOSE);
                doParse(queryNLReq, parseResp);
            }
            List<SemanticParseInfo> sortedParses = parseResp.getSelectedParses().stream()
                    .sorted(new SemanticParseInfo.SemanticParseComparator()).limit(1)
                    .collect(Collectors.toList());
            parseContext.getResponse().getSelectedParses().addAll(sortedParses);
        }
    }
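For context (not part of the patch): the SemanticParseComparator introduced further below ranks parses highest-first by dataset similarity, then metric similarity, then total similarity, then metric use count. A minimal usage sketch of how the parser above consumes it; "candidates" and "best" are hypothetical names for illustration only:

        // illustrative only: rank candidate parses and keep the single best one
        List<SemanticParseInfo> candidates = new ArrayList<>(parseResp.getSelectedParses());
        candidates.sort(new SemanticParseInfo.SemanticParseComparator());   // highest-ranked first
        SemanticParseInfo best = candidates.isEmpty() ? null : candidates.get(0);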
@@ -1,10 +1,7 @@
package com.tencent.supersonic.chat.server.processor.parse;

import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.parser.llm.DataSetMatchResult;
import lombok.extern.slf4j.Slf4j;

import java.util.*;
@@ -18,23 +15,7 @@ public class ParseInfoSortProcessor implements ParseResultProcessor {
    @Override
    public void process(ParseContext parseContext) {
        List<SemanticParseInfo> selectedParses = parseContext.getResponse().getSelectedParses();

        selectedParses.sort((o1, o2) -> {
            DataSetMatchResult mr1 = getDataSetMatchResult(o1.getElementMatches());
            DataSetMatchResult mr2 = getDataSetMatchResult(o2.getElementMatches());

            double difference = mr1.getMaxDatesetSimilarity() - mr2.getMaxDatesetSimilarity();
            if (difference == 0) {
                difference = mr1.getMaxMetricSimilarity() - mr2.getMaxMetricSimilarity();
                if (difference == 0) {
                    difference = mr1.getTotalSimilarity() - mr2.getTotalSimilarity();
                }
                if (difference == 0) {
                    difference = mr1.getMaxMetricUseCnt() - mr2.getMaxMetricUseCnt();
                }
            }
            return difference >= 0 ? -1 : 1;
        });
        selectedParses.sort(new SemanticParseInfo.SemanticParseComparator());
        // re-assign parseId
        for (int i = 0; i < selectedParses.size(); i++) {
            SemanticParseInfo parseInfo = selectedParses.get(i);
@@ -42,26 +23,4 @@ public class ParseInfoSortProcessor implements ParseResultProcessor {
        }
    }

    private DataSetMatchResult getDataSetMatchResult(List<SchemaElementMatch> elementMatches) {
        double maxMetricSimilarity = 0;
        double maxDatasetSimilarity = 0;
        double totalSimilarity = 0;
        long maxMetricUseCnt = 0L;
        for (SchemaElementMatch match : elementMatches) {
            if (SchemaElementType.DATASET.equals(match.getElement().getType())) {
                maxDatasetSimilarity = Math.max(maxDatasetSimilarity, match.getSimilarity());
            }
            if (SchemaElementType.METRIC.equals(match.getElement().getType())) {
                maxMetricSimilarity = Math.max(maxMetricSimilarity, match.getSimilarity());
                if (Objects.nonNull(match.getElement().getUseCnt())) {
                    maxMetricUseCnt = Math.max(maxMetricUseCnt, match.getElement().getUseCnt());
                }
            }
            totalSimilarity += match.getSimilarity();
        }
        return DataSetMatchResult.builder().maxMetricSimilarity(maxMetricSimilarity)
                .maxDatesetSimilarity(maxDatasetSimilarity).totalSimilarity(totalSimilarity)
                .build();
    }

}
@@ -9,6 +9,7 @@ import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.FilterType;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import lombok.Builder;
import lombok.Data;

import java.util.Comparator;
@@ -46,8 +47,58 @@ public class SemanticParseInfo {
    private String textInfo;
    private Map<String, Object> properties = Maps.newHashMap();

    private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
    @Data
    @Builder
    public static class DataSetMatchResult {
        private double maxMetricSimilarity;
        private double maxDatesetSimilarity;
        private double totalSimilarity;
        private long maxMetricUseCnt;
    }

    public static class SemanticParseComparator implements Comparator<SemanticParseInfo> {
        @Override
        public int compare(SemanticParseInfo o1, SemanticParseInfo o2) {
            DataSetMatchResult mr1 = getDataSetMatchResult(o1.getElementMatches());
            DataSetMatchResult mr2 = getDataSetMatchResult(o2.getElementMatches());

            double difference = mr1.getMaxDatesetSimilarity() - mr2.getMaxDatesetSimilarity();
            if (difference == 0) {
                difference = mr1.getMaxMetricSimilarity() - mr2.getMaxMetricSimilarity();
                if (difference == 0) {
                    difference = mr1.getTotalSimilarity() - mr2.getTotalSimilarity();
                }
                if (difference == 0) {
                    difference = mr1.getMaxMetricUseCnt() - mr2.getMaxMetricUseCnt();
                }
            }
            return difference >= 0 ? -1 : 1;
        }

        private DataSetMatchResult getDataSetMatchResult(List<SchemaElementMatch> elementMatches) {
            double maxMetricSimilarity = 0;
            double maxDatasetSimilarity = 0;
            double totalSimilarity = 0;
            long maxMetricUseCnt = 0L;
            for (SchemaElementMatch match : elementMatches) {
                if (SchemaElementType.DATASET.equals(match.getElement().getType())) {
                    maxDatasetSimilarity = Math.max(maxDatasetSimilarity, match.getSimilarity());
                }
                if (SchemaElementType.METRIC.equals(match.getElement().getType())) {
                    maxMetricSimilarity = Math.max(maxMetricSimilarity, match.getSimilarity());
                    if (Objects.nonNull(match.getElement().getUseCnt())) {
                        maxMetricUseCnt = Math.max(maxMetricUseCnt, match.getElement().getUseCnt());
                    }
                }
                totalSimilarity += match.getSimilarity();
            }
            return DataSetMatchResult.builder().maxMetricSimilarity(maxMetricSimilarity)
                    .maxDatesetSimilarity(maxDatasetSimilarity).totalSimilarity(totalSimilarity)
                    .build();
        }
    }

    private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
        @Override
        public int compare(SchemaElement o1, SchemaElement o2) {
            if (o1.getOrder() != o2.getOrder()) {
@@ -93,4 +144,19 @@ public class SemanticParseInfo {
        }
        return limit;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o)
            return true;
        if (o == null || getClass() != o.getClass())
            return false;
        SemanticParseInfo that = (SemanticParseInfo) o;
        return Objects.equals(textInfo, that.textInfo);
    }

    @Override
    public int hashCode() {
        return Objects.hashCode(textInfo);
    }
}
@@ -1,13 +0,0 @@
package com.tencent.supersonic.headless.chat.parser.llm;

import lombok.Builder;
import lombok.Data;

@Data
@Builder
public class DataSetMatchResult {
    private double maxMetricSimilarity;
    private double maxDatesetSimilarity;
    private double totalSimilarity;
    private Long maxMetricUseCnt;
}
@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@@ -36,8 +37,9 @@ public class HeuristicDataSetResolver implements DataSetResolver {
    }

    protected Long selectDataSetByMatchSimilarity(SchemaMapInfo schemaMap) {
        Map<Long, DataSetMatchResult> dataSetMatchRet = getDataSetMatchResult(schemaMap);
        Entry<Long, DataSetMatchResult> selectedDataset =
        Map<Long, SemanticParseInfo.DataSetMatchResult> dataSetMatchRet =
                getDataSetMatchResult(schemaMap);
        Entry<Long, SemanticParseInfo.DataSetMatchResult> selectedDataset =
                dataSetMatchRet.entrySet().stream().sorted((o1, o2) -> {
                    double difference = o1.getValue().getMaxDatesetSimilarity()
                            - o2.getValue().getMaxDatesetSimilarity();
@@ -63,8 +65,9 @@ public class HeuristicDataSetResolver implements DataSetResolver {
        return null;
    }

    protected Map<Long, DataSetMatchResult> getDataSetMatchResult(SchemaMapInfo schemaMap) {
        Map<Long, DataSetMatchResult> dateSetMatchRet = new HashMap<>();
    protected Map<Long, SemanticParseInfo.DataSetMatchResult> getDataSetMatchResult(
            SchemaMapInfo schemaMap) {
        Map<Long, SemanticParseInfo.DataSetMatchResult> dateSetMatchRet = new HashMap<>();
        for (Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getDataSetElementMatches()
                .entrySet()) {
            double maxMetricSimilarity = 0;
@@ -84,7 +87,8 @@ public class HeuristicDataSetResolver implements DataSetResolver {
            totalSimilarity += match.getSimilarity();
        }
        dateSetMatchRet.put(entry.getKey(),
                DataSetMatchResult.builder().maxMetricSimilarity(maxMetricSimilarity)
                SemanticParseInfo.DataSetMatchResult.builder()
                        .maxMetricSimilarity(maxMetricSimilarity)
                        .maxDatesetSimilarity(maxDatasetSimilarity)
                        .totalSimilarity(totalSimilarity).build());
    }
@@ -129,8 +129,7 @@ public class S2VisitsDemo extends S2BaseDemo {
    public void addSampleChats(Integer agentId) {
        Long chatId = chatManageService.addChat(defaultUser, "样例对话1", agentId);
        submitText(chatId.intValue(), agentId, "超音数 访问次数");
        submitText(chatId.intValue(), agentId, "按部门统计");
        submitText(chatId.intValue(), agentId, "查询近30天");
        submitText(chatId.intValue(), agentId, "按部门统计近7天访问次数");
        submitText(chatId.intValue(), agentId, "alice 停留时长");
        submitText(chatId.intValue(), agentId, "访问次数最高的部门");
    }
@@ -1,90 +0,0 @@
package com.tencent.supersonic.chat;

import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricFilterQuery;
import com.tencent.supersonic.util.DataUtils;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;

import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE;

public class MultiTurnsTest extends BaseTest {

    @Test
    @Order(1)
    public void queryTest_01() throws Exception {
        QueryResult actualResult = submitMultiTurnChat("alice的访问次数", DataUtils.metricAgentId,
                DataUtils.MULTI_TURNS_CHAT_ID);

        QueryResult expectedResult = new QueryResult();
        SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
        expectedResult.setChatContext(expectedParseInfo);

        expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE);
        expectedParseInfo.setAggType(NONE);

        expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));

        expectedParseInfo.getDimensionFilters().add(
                DataUtils.getFilter("user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L));

        expectedParseInfo.setDateInfo(
                DataUtils.getDateConf(DateConf.DateMode.BETWEEN, unit, period, startDay, endDay));
        expectedParseInfo.setQueryType(QueryType.AGGREGATE);

        assertQueryResult(expectedResult, actualResult);
    }

    @Test
    @Order(2)
    public void queryTest_02() throws Exception {
        QueryResult actualResult = submitMultiTurnChat("停留时长呢", DataUtils.metricAgentId,
                DataUtils.MULTI_TURNS_CHAT_ID);

        QueryResult expectedResult = new QueryResult();
        SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
        expectedResult.setChatContext(expectedParseInfo);
        expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE);
        expectedParseInfo.setAggType(NONE);

        expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("停留时长"));

        expectedParseInfo.getDimensionFilters().add(
                DataUtils.getFilter("user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L));

        expectedParseInfo.setDateInfo(
                DataUtils.getDateConf(DateConf.DateMode.BETWEEN, unit, period, startDay, endDay));
        expectedParseInfo.setQueryType(QueryType.AGGREGATE);

        assertQueryResult(expectedResult, actualResult);
    }

    @Test
    @Order(3)
    public void queryTest_03() throws Exception {
        QueryResult actualResult = submitMultiTurnChat("lucy的如何", DataUtils.metricAgentId,
                DataUtils.MULTI_TURNS_CHAT_ID);

        QueryResult expectedResult = new QueryResult();
        SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
        expectedResult.setChatContext(expectedParseInfo);

        expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE);
        expectedParseInfo.setAggType(NONE);

        expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("停留时长"));

        expectedParseInfo.getDimensionFilters()
                .add(DataUtils.getFilter("user_name", FilterOperatorEnum.EQUALS, "lucy", "用户", 2L));

        expectedParseInfo.setDateInfo(
                DataUtils.getDateConf(DateConf.DateMode.BETWEEN, unit, period, startDay, endDay));
        expectedParseInfo.setQueryType(QueryType.AGGREGATE);

        assertQueryResult(expectedResult, actualResult);
    }
}