[improvement][chat]Optimize NL2SQL parsing logic.

This commit is contained in:
jerryjzhang
2024-10-29 20:33:32 +08:00
parent 996cb3df56
commit cbb76550c7
7 changed files with 94 additions and 168 deletions

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.chat.server.parser;
import com.google.common.collect.Lists;
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.server.pojo.ChatContext;
@@ -15,10 +16,12 @@ import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
@@ -35,6 +38,7 @@ import dev.langchain4j.provider.ModelProvider;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.CollectionUtils;
import java.util.*;
import java.util.stream.Collectors;
@@ -78,27 +82,24 @@ public class NL2SQLParser implements ChatQueryParser {
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
// inject semantic parse saved by in the chat context
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
ChatContext chatCtx =
chatContextService.getOrCreateContext(parseContext.getRequest().getChatId());
if (chatCtx != null && Objects.isNull(queryNLReq.getContextParseInfo())) {
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
}
// for every requested dataSet, recursively invoke rule-based parser
// with different mapModes, unless any valid semantic parse is derived.
// for every requested dataSet, recursively invoke rule-based parser with different
// mapModes
Set<Long> requestedDatasets = queryNLReq.getDataSetIds();
for (Long datasetId : requestedDatasets) {
queryNLReq.setDataSetIds(Collections.singleton(datasetId));
ChatParseResp parseResp = parseContext.getResponse();
for (MapModeEnum mode : MapModeEnum.values()) {
ChatParseResp parseResp = new ChatParseResp(parseContext.getRequest().getQueryId());
for (MapModeEnum mode : Lists.newArrayList(MapModeEnum.STRICT, MapModeEnum.MODERATE)) {
queryNLReq.setMapModeEnum(mode);
doParse(queryNLReq, parseResp);
if (!parseResp.getSelectedParses().isEmpty()) {
break;
}
}
if (parseResp.getSelectedParses().isEmpty()) {
queryNLReq.setMapModeEnum(MapModeEnum.LOOSE);
doParse(queryNLReq, parseResp);
}
List<SemanticParseInfo> sortedParses = parseResp.getSelectedParses().stream()
.sorted(new SemanticParseInfo.SemanticParseComparator()).limit(1)
.collect(Collectors.toList());
parseContext.getResponse().getSelectedParses().addAll(sortedParses);
}
}

View File

@@ -1,10 +1,7 @@
package com.tencent.supersonic.chat.server.processor.parse;
import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.parser.llm.DataSetMatchResult;
import lombok.extern.slf4j.Slf4j;
import java.util.*;
@@ -18,23 +15,7 @@ public class ParseInfoSortProcessor implements ParseResultProcessor {
@Override
public void process(ParseContext parseContext) {
List<SemanticParseInfo> selectedParses = parseContext.getResponse().getSelectedParses();
selectedParses.sort((o1, o2) -> {
DataSetMatchResult mr1 = getDataSetMatchResult(o1.getElementMatches());
DataSetMatchResult mr2 = getDataSetMatchResult(o2.getElementMatches());
double difference = mr1.getMaxDatesetSimilarity() - mr2.getMaxDatesetSimilarity();
if (difference == 0) {
difference = mr1.getMaxMetricSimilarity() - mr2.getMaxMetricSimilarity();
if (difference == 0) {
difference = mr1.getTotalSimilarity() - mr2.getTotalSimilarity();
}
if (difference == 0) {
difference = mr1.getMaxMetricUseCnt() - mr2.getMaxMetricUseCnt();
}
}
return difference >= 0 ? -1 : 1;
});
selectedParses.sort(new SemanticParseInfo.SemanticParseComparator());
// re-assign parseId
for (int i = 0; i < selectedParses.size(); i++) {
SemanticParseInfo parseInfo = selectedParses.get(i);
@@ -42,26 +23,4 @@ public class ParseInfoSortProcessor implements ParseResultProcessor {
}
}
private DataSetMatchResult getDataSetMatchResult(List<SchemaElementMatch> elementMatches) {
double maxMetricSimilarity = 0;
double maxDatasetSimilarity = 0;
double totalSimilarity = 0;
long maxMetricUseCnt = 0L;
for (SchemaElementMatch match : elementMatches) {
if (SchemaElementType.DATASET.equals(match.getElement().getType())) {
maxDatasetSimilarity = Math.max(maxDatasetSimilarity, match.getSimilarity());
}
if (SchemaElementType.METRIC.equals(match.getElement().getType())) {
maxMetricSimilarity = Math.max(maxMetricSimilarity, match.getSimilarity());
if (Objects.nonNull(match.getElement().getUseCnt())) {
maxMetricUseCnt = Math.max(maxMetricUseCnt, match.getElement().getUseCnt());
}
}
totalSimilarity += match.getSimilarity();
}
return DataSetMatchResult.builder().maxMetricSimilarity(maxMetricSimilarity)
.maxDatesetSimilarity(maxDatasetSimilarity).totalSimilarity(totalSimilarity)
.build();
}
}