mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-14 22:25:19 +00:00
[improvement][chat]Optimize NL2SQL parsing logic.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.ChatParseResp;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
@@ -15,10 +16,12 @@ import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.MapResp;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
|
||||
@@ -35,6 +38,7 @@ import dev.langchain4j.provider.ModelProvider;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
@@ -78,27 +82,24 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
||||
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
|
||||
|
||||
// inject semantic parse saved by in the chat context
|
||||
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
||||
ChatContext chatCtx =
|
||||
chatContextService.getOrCreateContext(parseContext.getRequest().getChatId());
|
||||
if (chatCtx != null && Objects.isNull(queryNLReq.getContextParseInfo())) {
|
||||
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
|
||||
}
|
||||
|
||||
// for every requested dataSet, recursively invoke rule-based parser
|
||||
// with different mapModes, unless any valid semantic parse is derived.
|
||||
// for every requested dataSet, recursively invoke rule-based parser with different
|
||||
// mapModes
|
||||
Set<Long> requestedDatasets = queryNLReq.getDataSetIds();
|
||||
for (Long datasetId : requestedDatasets) {
|
||||
queryNLReq.setDataSetIds(Collections.singleton(datasetId));
|
||||
ChatParseResp parseResp = parseContext.getResponse();
|
||||
for (MapModeEnum mode : MapModeEnum.values()) {
|
||||
ChatParseResp parseResp = new ChatParseResp(parseContext.getRequest().getQueryId());
|
||||
for (MapModeEnum mode : Lists.newArrayList(MapModeEnum.STRICT, MapModeEnum.MODERATE)) {
|
||||
queryNLReq.setMapModeEnum(mode);
|
||||
doParse(queryNLReq, parseResp);
|
||||
if (!parseResp.getSelectedParses().isEmpty()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (parseResp.getSelectedParses().isEmpty()) {
|
||||
queryNLReq.setMapModeEnum(MapModeEnum.LOOSE);
|
||||
doParse(queryNLReq, parseResp);
|
||||
}
|
||||
List<SemanticParseInfo> sortedParses = parseResp.getSelectedParses().stream()
|
||||
.sorted(new SemanticParseInfo.SemanticParseComparator()).limit(1)
|
||||
.collect(Collectors.toList());
|
||||
parseContext.getResponse().getSelectedParses().addAll(sortedParses);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.parser.llm.DataSetMatchResult;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.*;
|
||||
@@ -18,23 +15,7 @@ public class ParseInfoSortProcessor implements ParseResultProcessor {
|
||||
@Override
|
||||
public void process(ParseContext parseContext) {
|
||||
List<SemanticParseInfo> selectedParses = parseContext.getResponse().getSelectedParses();
|
||||
|
||||
selectedParses.sort((o1, o2) -> {
|
||||
DataSetMatchResult mr1 = getDataSetMatchResult(o1.getElementMatches());
|
||||
DataSetMatchResult mr2 = getDataSetMatchResult(o2.getElementMatches());
|
||||
|
||||
double difference = mr1.getMaxDatesetSimilarity() - mr2.getMaxDatesetSimilarity();
|
||||
if (difference == 0) {
|
||||
difference = mr1.getMaxMetricSimilarity() - mr2.getMaxMetricSimilarity();
|
||||
if (difference == 0) {
|
||||
difference = mr1.getTotalSimilarity() - mr2.getTotalSimilarity();
|
||||
}
|
||||
if (difference == 0) {
|
||||
difference = mr1.getMaxMetricUseCnt() - mr2.getMaxMetricUseCnt();
|
||||
}
|
||||
}
|
||||
return difference >= 0 ? -1 : 1;
|
||||
});
|
||||
selectedParses.sort(new SemanticParseInfo.SemanticParseComparator());
|
||||
// re-assign parseId
|
||||
for (int i = 0; i < selectedParses.size(); i++) {
|
||||
SemanticParseInfo parseInfo = selectedParses.get(i);
|
||||
@@ -42,26 +23,4 @@ public class ParseInfoSortProcessor implements ParseResultProcessor {
|
||||
}
|
||||
}
|
||||
|
||||
private DataSetMatchResult getDataSetMatchResult(List<SchemaElementMatch> elementMatches) {
|
||||
double maxMetricSimilarity = 0;
|
||||
double maxDatasetSimilarity = 0;
|
||||
double totalSimilarity = 0;
|
||||
long maxMetricUseCnt = 0L;
|
||||
for (SchemaElementMatch match : elementMatches) {
|
||||
if (SchemaElementType.DATASET.equals(match.getElement().getType())) {
|
||||
maxDatasetSimilarity = Math.max(maxDatasetSimilarity, match.getSimilarity());
|
||||
}
|
||||
if (SchemaElementType.METRIC.equals(match.getElement().getType())) {
|
||||
maxMetricSimilarity = Math.max(maxMetricSimilarity, match.getSimilarity());
|
||||
if (Objects.nonNull(match.getElement().getUseCnt())) {
|
||||
maxMetricUseCnt = Math.max(maxMetricUseCnt, match.getElement().getUseCnt());
|
||||
}
|
||||
}
|
||||
totalSimilarity += match.getSimilarity();
|
||||
}
|
||||
return DataSetMatchResult.builder().maxMetricSimilarity(maxMetricSimilarity)
|
||||
.maxDatesetSimilarity(maxDatasetSimilarity).totalSimilarity(totalSimilarity)
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user