mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 20:25:12 +00:00
(improvement)(Headless) distinct select fields in S2CorrectSQL (#912)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
package com.tencent.supersonic.headless.core.chat.corrector;
|
||||
|
||||
import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.core.pojo.QueryContext;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -28,5 +29,12 @@ public class GrammarCorrector extends BaseSemanticCorrector {
|
||||
for (BaseSemanticCorrector corrector : correctors) {
|
||||
corrector.correct(queryContext, semanticParseInfo);
|
||||
}
|
||||
removeSameFieldFromSelect(semanticParseInfo);
|
||||
}
|
||||
|
||||
public void removeSameFieldFromSelect(SemanticParseInfo semanticParseInfo) {
|
||||
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
|
||||
correctS2SQL = SqlRemoveHelper.removeSameFieldFromSelect(correctS2SQL);
|
||||
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,11 +61,12 @@ public class LLMRequestService {
|
||||
}
|
||||
|
||||
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId,
|
||||
SemanticSchema semanticSchema, List<LLMReq.ElementValue> linkingValues) {
|
||||
SemanticSchema semanticSchema, List<LLMReq.ElementValue> linkingValues) {
|
||||
Map<Long, String> dataSetIdToName = semanticSchema.getDataSetIdToName();
|
||||
String queryText = queryCtx.getQueryText();
|
||||
|
||||
LLMReq llmReq = new LLMReq();
|
||||
|
||||
llmReq.setQueryText(queryText);
|
||||
LLMReq.FilterCondition filterCondition = new LLMReq.FilterCondition();
|
||||
llmReq.setFilterCondition(filterCondition);
|
||||
@@ -103,7 +104,7 @@ public class LLMRequestService {
|
||||
}
|
||||
|
||||
protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId,
|
||||
LLMParserConfig llmParserConfig) {
|
||||
LLMParserConfig llmParserConfig) {
|
||||
|
||||
Set<String> results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig);
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
|
||||
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.collections4.MapUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
@@ -28,6 +29,7 @@ public class LLMSqlParser implements SemanticParser {
|
||||
try {
|
||||
//2.get dataSetId from queryCtx and chatCtx.
|
||||
Long dataSetId = requestService.getDataSetId(queryCtx);
|
||||
log.info("dataSetId:{}", dataSetId);
|
||||
if (dataSetId == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -6,11 +6,11 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Map;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
@@ -85,7 +85,7 @@ public class SqlPromptGenerator {
|
||||
}
|
||||
|
||||
public List<List<Map<String, String>>> getExampleCombos(List<Map<String, String>> exampleList, int numFewShots,
|
||||
int numSelfConsistency) {
|
||||
int numSelfConsistency) {
|
||||
List<List<Map<String, String>>> results = new ArrayList<>();
|
||||
for (int i = 0; i < numSelfConsistency; i++) {
|
||||
List<Map<String, String>> shuffledList = new ArrayList<>(exampleList);
|
||||
@@ -118,7 +118,7 @@ public class SqlPromptGenerator {
|
||||
}
|
||||
|
||||
public List<String> generateSqlPromptPool(LLMReq llmReq, List<String> schemaLinkStrPool,
|
||||
List<List<Map<String, String>>> fewshotExampleListPool) {
|
||||
List<List<Map<String, String>>> fewshotExampleListPool) {
|
||||
List<String> sqlPromptPool = new ArrayList<>();
|
||||
for (int i = 0; i < schemaLinkStrPool.size(); i++) {
|
||||
String schemaLinkStr = schemaLinkStrPool.get(i);
|
||||
@@ -129,4 +129,4 @@ public class SqlPromptGenerator {
|
||||
return sqlPromptPool;
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import lombok.Builder;
|
||||
import lombok.Data;
|
||||
import lombok.NoArgsConstructor;
|
||||
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
|
||||
Reference in New Issue
Block a user