mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 11:07:06 +00:00
[improvement][chat]Modify core workflow of NL2SQLParser, always invoking rule-based parsers first.#1729
This commit is contained in:
@@ -22,4 +22,5 @@ public class ChatParseResp {
|
||||
public ChatParseResp(Long queryId) {
|
||||
this.queryId = queryId;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -36,12 +36,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
|
||||
@@ -78,29 +73,46 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
return;
|
||||
}
|
||||
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
||||
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
||||
ChatContext chatCtx =
|
||||
chatContextService.getOrCreateContext(parseContext.getRequest().getChatId());
|
||||
if (chatCtx != null && Objects.isNull(queryNLReq.getContextParseInfo())) {
|
||||
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
|
||||
}
|
||||
|
||||
if (parseContext.needRuleParse()) {
|
||||
// first go with rule-based parsers unless the user has already selected one parse.
|
||||
if (Objects.isNull(parseContext.getRequest().getSelectedParse())) {
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
||||
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
|
||||
ChatParseResp parseResp = parseContext.getResponse();
|
||||
for (MapModeEnum mode : MapModeEnum.values()) {
|
||||
queryNLReq.setMapModeEnum(mode);
|
||||
doParse(queryNLReq, parseResp);
|
||||
|
||||
// inject semantic parse saved by in the chat context
|
||||
ChatContextService chatContextService = ContextUtils.getBean(ChatContextService.class);
|
||||
ChatContext chatCtx =
|
||||
chatContextService.getOrCreateContext(parseContext.getRequest().getChatId());
|
||||
if (chatCtx != null && Objects.isNull(queryNLReq.getContextParseInfo())) {
|
||||
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
|
||||
}
|
||||
|
||||
// for every requested dataSet, recursively invoke rule-based parser
|
||||
// with different mapModes, unless any valid semantic parse is derived.
|
||||
Set<Long> requestedDatasets = queryNLReq.getDataSetIds();
|
||||
for (Long datasetId : requestedDatasets) {
|
||||
queryNLReq.setDataSetIds(Collections.singleton(datasetId));
|
||||
ChatParseResp parseResp = parseContext.getResponse();
|
||||
for (MapModeEnum mode : MapModeEnum.values()) {
|
||||
queryNLReq.setMapModeEnum(mode);
|
||||
doParse(queryNLReq, parseResp);
|
||||
if (!parseResp.getSelectedParses().isEmpty()) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// next go with llm-based parsers unless LLM is disabled or use feedback is needed.
|
||||
if (parseContext.needLLMParse() && !parseContext.needFeedback()) {
|
||||
SemanticParseInfo selectedParse = parseContext.getRequest().getSelectedParse();
|
||||
queryNLReq.setSelectedParseInfo(Objects.nonNull(selectedParse) ? selectedParse
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildQueryNLReq(parseContext);
|
||||
queryNLReq.setText2SQLType(Text2SQLType.LLM_OR_RULE);
|
||||
|
||||
// either the user or the system selects one parse from the candidate parses.
|
||||
SemanticParseInfo userSelectParse = parseContext.getRequest().getSelectedParse();
|
||||
queryNLReq.setSelectedParseInfo(Objects.nonNull(userSelectParse) ? userSelectParse
|
||||
: parseContext.getResponse().getSelectedParses().get(0));
|
||||
queryNLReq.setText2SQLType(Text2SQLType.RULE_AND_LLM);
|
||||
parseContext.getResponse().getSelectedParses().clear();
|
||||
|
||||
parseContext.setResponse(new ChatParseResp(parseContext.getResponse().getQueryId()));
|
||||
rewriteMultiTurn(parseContext, queryNLReq);
|
||||
addDynamicExemplars(parseContext, queryNLReq);
|
||||
doParse(queryNLReq, parseContext.getResponse());
|
||||
|
||||
@@ -31,10 +31,6 @@ public class ParseContext {
|
||||
&& response.getSelectedParses().size() > 1);
|
||||
}
|
||||
|
||||
public boolean needRuleParse() {
|
||||
return Objects.isNull(request.getSelectedParse());
|
||||
}
|
||||
|
||||
public boolean needLLMParse() {
|
||||
return enableLLM() && (Objects.nonNull(request.getSelectedParse())
|
||||
|| !response.getSelectedParses().isEmpty());
|
||||
|
||||
@@ -1,30 +1,67 @@
|
||||
package com.tencent.supersonic.chat.server.processor.parse;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementType;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.parser.llm.DataSetMatchResult;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* ParseInfoSortProcessor sorts candidate parse info based on certain algorithm. \
|
||||
**/
|
||||
@Slf4j
|
||||
public class ParseInfoSortProcessor implements ParseResultProcessor {
|
||||
|
||||
@Override
|
||||
public void process(ParseContext parseContext) {
|
||||
Set<String> parseInfoText = Sets.newHashSet();
|
||||
List<SemanticParseInfo> sortedParseInfo = Lists.newArrayList();
|
||||
List<SemanticParseInfo> selectedParses = parseContext.getResponse().getSelectedParses();
|
||||
|
||||
parseContext.getResponse().getSelectedParses().forEach(p -> {
|
||||
if (!parseInfoText.contains(p.getTextInfo())) {
|
||||
sortedParseInfo.add(p);
|
||||
parseInfoText.add(p.getTextInfo());
|
||||
selectedParses.sort((o1, o2) -> {
|
||||
DataSetMatchResult mr1 = getDataSetMatchResult(o1.getElementMatches());
|
||||
DataSetMatchResult mr2 = getDataSetMatchResult(o2.getElementMatches());
|
||||
|
||||
double difference = mr1.getMaxDatesetSimilarity() - mr2.getMaxDatesetSimilarity();
|
||||
if (difference == 0) {
|
||||
difference = mr1.getMaxMetricSimilarity() - mr2.getMaxMetricSimilarity();
|
||||
if (difference == 0) {
|
||||
difference = mr1.getTotalSimilarity() - mr2.getTotalSimilarity();
|
||||
}
|
||||
if (difference == 0) {
|
||||
difference = mr1.getMaxMetricUseCnt() - mr2.getMaxMetricUseCnt();
|
||||
}
|
||||
}
|
||||
return difference >= 0 ? -1 : 1;
|
||||
});
|
||||
|
||||
sortedParseInfo.sort((o1, o2) -> o1.getScore() - o2.getScore() > 0 ? 1 : 0);
|
||||
parseContext.getResponse().setSelectedParses(sortedParseInfo);
|
||||
// re-assign parseId
|
||||
for (int i = 0; i < selectedParses.size(); i++) {
|
||||
SemanticParseInfo parseInfo = selectedParses.get(i);
|
||||
parseInfo.setId(i + 1);
|
||||
}
|
||||
}
|
||||
|
||||
private DataSetMatchResult getDataSetMatchResult(List<SchemaElementMatch> elementMatches) {
|
||||
double maxMetricSimilarity = 0;
|
||||
double maxDatasetSimilarity = 0;
|
||||
double totalSimilarity = 0;
|
||||
long maxMetricUseCnt = 0L;
|
||||
for (SchemaElementMatch match : elementMatches) {
|
||||
if (SchemaElementType.DATASET.equals(match.getElement().getType())) {
|
||||
maxDatasetSimilarity = Math.max(maxDatasetSimilarity, match.getSimilarity());
|
||||
}
|
||||
if (SchemaElementType.METRIC.equals(match.getElement().getType())) {
|
||||
maxMetricSimilarity = Math.max(maxMetricSimilarity, match.getSimilarity());
|
||||
if (Objects.nonNull(match.getElement().getUseCnt())) {
|
||||
maxMetricUseCnt = Math.max(maxMetricUseCnt, match.getElement().getUseCnt());
|
||||
}
|
||||
}
|
||||
totalSimilarity += match.getSimilarity();
|
||||
}
|
||||
return DataSetMatchResult.builder().maxMetricSimilarity(maxMetricSimilarity)
|
||||
.maxDatesetSimilarity(maxDatasetSimilarity).totalSimilarity(totalSimilarity)
|
||||
.build();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ public class QueryReqConverter {
|
||||
QueryNLReq queryNLReq = new QueryNLReq();
|
||||
BeanMapper.mapper(parseContext.getRequest(), queryNLReq);
|
||||
queryNLReq.setText2SQLType(
|
||||
parseContext.enableLLM() ? Text2SQLType.RULE_AND_LLM : Text2SQLType.ONLY_RULE);
|
||||
parseContext.enableLLM() ? Text2SQLType.LLM_OR_RULE : Text2SQLType.ONLY_RULE);
|
||||
queryNLReq.setDataSetIds(getDataSetIds(parseContext));
|
||||
queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig());
|
||||
queryNLReq.setSelectedParseInfo(parseContext.getRequest().getSelectedParse());
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
package com.tencent.supersonic.common.pojo.enums;
|
||||
|
||||
public enum Text2SQLType {
|
||||
ONLY_RULE, ONLY_LLM, RULE_AND_LLM;
|
||||
ONLY_RULE, ONLY_LLM, LLM_OR_RULE;
|
||||
|
||||
public boolean enableRule() {
|
||||
return this.equals(ONLY_RULE) || this.equals(RULE_AND_LLM);
|
||||
return this.equals(ONLY_RULE) || this.equals(LLM_OR_RULE);
|
||||
}
|
||||
|
||||
public boolean enableLLM() {
|
||||
return this.equals(ONLY_LLM) || this.equals(RULE_AND_LLM);
|
||||
return this.equals(ONLY_LLM) || this.equals(LLM_OR_RULE);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ public class QueryNLReq extends SemanticQueryReq {
|
||||
private User user;
|
||||
private QueryFilters queryFilters;
|
||||
private boolean saveAnswer = true;
|
||||
private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM;
|
||||
private Text2SQLType text2SQLType = Text2SQLType.LLM_OR_RULE;
|
||||
private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
|
||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||
private Map<String, ChatApp> chatAppConfig;
|
||||
|
||||
@@ -60,15 +60,12 @@ public class QueryReqBuilder {
|
||||
queryStructReq.setGroups(parseInfo.getDimensions().stream().map(SchemaElement::getBizName)
|
||||
.collect(Collectors.toList()));
|
||||
queryStructReq.setLimit(parseInfo.getLimit());
|
||||
// only one metric is queried at once
|
||||
Set<SchemaElement> metrics = parseInfo.getMetrics();
|
||||
if (!CollectionUtils.isEmpty(metrics)) {
|
||||
SchemaElement metricElement = parseInfo.getMetrics().iterator().next();
|
||||
Set<Order> order =
|
||||
getOrder(parseInfo.getOrders(), parseInfo.getAggType(), metricElement);
|
||||
queryStructReq
|
||||
.setAggregators(getAggregatorByMetric(parseInfo.getAggType(), metricElement));
|
||||
queryStructReq.setOrders(new ArrayList<>(order));
|
||||
|
||||
for (SchemaElement metricElement : parseInfo.getMetrics()) {
|
||||
queryStructReq.getAggregators()
|
||||
.addAll(getAggregatorByMetric(parseInfo.getAggType(), metricElement));
|
||||
queryStructReq.setOrders(new ArrayList<>(
|
||||
getOrder(parseInfo.getOrders(), parseInfo.getAggType(), metricElement)));
|
||||
}
|
||||
|
||||
deletionDuplicated(queryStructReq);
|
||||
|
||||
@@ -40,6 +40,7 @@ public class DataUtils {
|
||||
public static ChatParseReq getChatParseReq(Integer id, String query, boolean enableLLM) {
|
||||
ChatParseReq chatParseReq = new ChatParseReq();
|
||||
chatParseReq.setQueryText(query);
|
||||
chatParseReq.setAgentId(metricAgentId);
|
||||
chatParseReq.setChatId(id);
|
||||
chatParseReq.setUser(user_test);
|
||||
chatParseReq.setDisableLLM(!enableLLM);
|
||||
|
||||
Reference in New Issue
Block a user