[improvement][chat]Optimize NL2SQL parsing logic.

This commit is contained in:
jerryjzhang
2024-10-29 20:33:32 +08:00
parent 996cb3df56
commit cbb76550c7
7 changed files with 94 additions and 168 deletions

View File

@@ -9,6 +9,7 @@ import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.FilterType;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import lombok.Builder;
import lombok.Data;
import java.util.Comparator;
@@ -46,8 +47,58 @@ public class SemanticParseInfo {
private String textInfo;
private Map<String, Object> properties = Maps.newHashMap();
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
/**
 * Aggregated match scores for one list of schema element matches; the comparator in this file
 * reads these values to rank parse candidates. Lombok generates the accessors and the builder.
 */
@Data
@Builder
public static class DataSetMatchResult {
// Highest similarity among matches whose element type is METRIC.
private double maxMetricSimilarity;
// Highest similarity among matches whose element type is DATASET. The "Dateset" spelling is
// kept as-is because the generated accessor and builder method names derive from it.
private double maxDatesetSimilarity;
// Sum of the similarity of every match, regardless of element type.
private double totalSimilarity;
// Highest use count among matched METRIC elements; stays 0 when the builder is not given one.
private long maxMetricUseCnt;
}
/**
 * Orders {@link SemanticParseInfo} candidates so that the best schema match sorts first.
 *
 * <p>Candidates are ranked in descending order by, in turn: the highest DATASET match
 * similarity, the highest METRIC match similarity, the sum of all match similarities, and the
 * highest use count among matched metrics. Candidates that tie on every criterion compare as
 * equal ({@code 0}); returning a non-zero value for ties would break the symmetry required by
 * the {@link Comparator} contract and can make the JDK sort routines throw.
 */
public static class SemanticParseComparator implements Comparator<SemanticParseInfo> {

    /**
     * Compares two parse candidates by the strength of their schema element matches.
     *
     * @param o1 the first candidate
     * @param o2 the second candidate
     * @return a negative value when {@code o1} is the better match, a positive value when
     *         {@code o2} is the better match, and {@code 0} when they tie on every criterion
     */
    @Override
    public int compare(SemanticParseInfo o1, SemanticParseInfo o2) {
        DataSetMatchResult mr1 = getDataSetMatchResult(o1.getElementMatches());
        DataSetMatchResult mr2 = getDataSetMatchResult(o2.getElementMatches());
        // Operands are swapped (mr2 first) in every step to obtain descending order.
        int result =
                Double.compare(mr2.getMaxDatesetSimilarity(), mr1.getMaxDatesetSimilarity());
        if (result == 0) {
            result = Double.compare(mr2.getMaxMetricSimilarity(), mr1.getMaxMetricSimilarity());
        }
        if (result == 0) {
            result = Double.compare(mr2.getTotalSimilarity(), mr1.getTotalSimilarity());
        }
        if (result == 0) {
            // Long.compare avoids routing a long difference through a double.
            result = Long.compare(mr2.getMaxMetricUseCnt(), mr1.getMaxMetricUseCnt());
        }
        return result;
    }

    /**
     * Folds a list of element matches into the scores used for ranking.
     *
     * @param elementMatches the matches of one candidate; {@code null} is treated as empty
     * @return the aggregated scores, including the highest metric use count
     */
    private DataSetMatchResult getDataSetMatchResult(List<SchemaElementMatch> elementMatches) {
        double maxMetricSimilarity = 0;
        double maxDatasetSimilarity = 0;
        double totalSimilarity = 0;
        long maxMetricUseCnt = 0L;
        if (Objects.nonNull(elementMatches)) {
            for (SchemaElementMatch match : elementMatches) {
                if (SchemaElementType.DATASET.equals(match.getElement().getType())) {
                    maxDatasetSimilarity =
                            Math.max(maxDatasetSimilarity, match.getSimilarity());
                }
                if (SchemaElementType.METRIC.equals(match.getElement().getType())) {
                    maxMetricSimilarity = Math.max(maxMetricSimilarity, match.getSimilarity());
                    // The use count may be absent on an element; skip it in that case.
                    if (Objects.nonNull(match.getElement().getUseCnt())) {
                        maxMetricUseCnt =
                                Math.max(maxMetricUseCnt, match.getElement().getUseCnt());
                    }
                }
                totalSimilarity += match.getSimilarity();
            }
        }
        // maxMetricUseCnt must be handed to the builder, otherwise the last tie-break in
        // compare() always sees 0 on both sides.
        return DataSetMatchResult.builder().maxMetricSimilarity(maxMetricSimilarity)
                .maxDatesetSimilarity(maxDatasetSimilarity).totalSimilarity(totalSimilarity)
                .maxMetricUseCnt(maxMetricUseCnt).build();
    }
}
private static class SchemaNameLengthComparator implements Comparator<SchemaElement> {
@Override
public int compare(SchemaElement o1, SchemaElement o2) {
if (o1.getOrder() != o2.getOrder()) {
@@ -93,4 +144,19 @@ public class SemanticParseInfo {
}
return limit;
}
/**
 * Two parse infos are equal when they are of the same runtime class and carry equal
 * {@code textInfo} values; no other field takes part in the comparison.
 *
 * @param o the object to compare with
 * @return {@code true} when {@code o} is this instance or an equal one
 */
@Override
public boolean equals(Object o) {
    if (o == this) {
        return true;
    }
    boolean sameClass = o != null && o.getClass() == getClass();
    if (!sameClass) {
        return false;
    }
    SemanticParseInfo other = (SemanticParseInfo) o;
    return Objects.equals(this.textInfo, other.textInfo);
}
/**
 * Hash code derived from {@code textInfo} only, matching the field used by {@code equals}.
 *
 * @return the hash of {@code textInfo}, or {@code 0} when it is {@code null}
 */
@Override
public int hashCode() {
    if (textInfo == null) {
        return 0;
    }
    return textInfo.hashCode();
}
}

View File

@@ -1,13 +0,0 @@
package com.tencent.supersonic.headless.chat.parser.llm;
import lombok.Builder;
import lombok.Data;
/**
 * Per-dataset match scores collected by the dataset resolver in this package while it chooses a
 * dataset. Lombok generates the accessors and the builder.
 */
@Data
@Builder
public class DataSetMatchResult {
// NOTE(review): presumably the highest similarity among metric matches - the loop that fills
// it is only partly visible from here; confirm against the resolver.
private double maxMetricSimilarity;
// NOTE(review): presumably the highest similarity among dataset matches; the "Dateset"
// spelling is kept as-is because the generated accessor names derive from it.
private double maxDatesetSimilarity;
// Sum of the similarity of the matches of one dataset.
private double totalSimilarity;
// Boxed, so it is null unless a value is handed to the builder.
private Long maxMetricUseCnt;
}

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SchemaMapInfo;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections.CollectionUtils;
@@ -36,8 +37,9 @@ public class HeuristicDataSetResolver implements DataSetResolver {
}
protected Long selectDataSetByMatchSimilarity(SchemaMapInfo schemaMap) {
Map<Long, DataSetMatchResult> dataSetMatchRet = getDataSetMatchResult(schemaMap);
Entry<Long, DataSetMatchResult> selectedDataset =
Map<Long, SemanticParseInfo.DataSetMatchResult> dataSetMatchRet =
getDataSetMatchResult(schemaMap);
Entry<Long, SemanticParseInfo.DataSetMatchResult> selectedDataset =
dataSetMatchRet.entrySet().stream().sorted((o1, o2) -> {
double difference = o1.getValue().getMaxDatesetSimilarity()
- o2.getValue().getMaxDatesetSimilarity();
@@ -63,8 +65,9 @@ public class HeuristicDataSetResolver implements DataSetResolver {
return null;
}
protected Map<Long, DataSetMatchResult> getDataSetMatchResult(SchemaMapInfo schemaMap) {
Map<Long, DataSetMatchResult> dateSetMatchRet = new HashMap<>();
protected Map<Long, SemanticParseInfo.DataSetMatchResult> getDataSetMatchResult(
SchemaMapInfo schemaMap) {
Map<Long, SemanticParseInfo.DataSetMatchResult> dateSetMatchRet = new HashMap<>();
for (Entry<Long, List<SchemaElementMatch>> entry : schemaMap.getDataSetElementMatches()
.entrySet()) {
double maxMetricSimilarity = 0;
@@ -84,7 +87,8 @@ public class HeuristicDataSetResolver implements DataSetResolver {
totalSimilarity += match.getSimilarity();
}
dateSetMatchRet.put(entry.getKey(),
DataSetMatchResult.builder().maxMetricSimilarity(maxMetricSimilarity)
SemanticParseInfo.DataSetMatchResult.builder()
.maxMetricSimilarity(maxMetricSimilarity)
.maxDatesetSimilarity(maxDatasetSimilarity)
.totalSimilarity(totalSimilarity).build());
}