(improvement)(Headless) distinct select fields in S2CorrectSQL (#912)

This commit is contained in:
mainmain
2024-04-16 16:30:00 +08:00
committed by GitHub
parent 6e0fc87a57
commit 25a618b175
13 changed files with 87 additions and 18 deletions

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.core.chat.corrector;
import com.tencent.supersonic.common.util.jsqlparser.SqlRemoveHelper;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.core.pojo.QueryContext;
import lombok.extern.slf4j.Slf4j;
@@ -28,5 +29,12 @@ public class GrammarCorrector extends BaseSemanticCorrector {
for (BaseSemanticCorrector corrector : correctors) {
corrector.correct(queryContext, semanticParseInfo);
}
removeSameFieldFromSelect(semanticParseInfo);
}
public void removeSameFieldFromSelect(SemanticParseInfo semanticParseInfo) {
String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectS2SQL();
correctS2SQL = SqlRemoveHelper.removeSameFieldFromSelect(correctS2SQL);
semanticParseInfo.getSqlInfo().setCorrectS2SQL(correctS2SQL);
}
}

View File

@@ -61,11 +61,12 @@ public class LLMRequestService {
}
public LLMReq getLlmReq(QueryContext queryCtx, Long dataSetId,
SemanticSchema semanticSchema, List<LLMReq.ElementValue> linkingValues) {
SemanticSchema semanticSchema, List<LLMReq.ElementValue> linkingValues) {
Map<Long, String> dataSetIdToName = semanticSchema.getDataSetIdToName();
String queryText = queryCtx.getQueryText();
LLMReq llmReq = new LLMReq();
llmReq.setQueryText(queryText);
LLMReq.FilterCondition filterCondition = new LLMReq.FilterCondition();
llmReq.setFilterCondition(filterCondition);
@@ -103,7 +104,7 @@ public class LLMRequestService {
}
protected List<String> getFieldNameList(QueryContext queryCtx, Long dataSetId,
LLMParserConfig llmParserConfig) {
LLMParserConfig llmParserConfig) {
Set<String> results = getTopNFieldNames(queryCtx, dataSetId, llmParserConfig);

View File

@@ -11,6 +11,7 @@ import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMResp;
import com.tencent.supersonic.headless.core.chat.query.llm.s2sql.LLMSqlResp;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.collections4.MapUtils;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -28,6 +29,7 @@ public class LLMSqlParser implements SemanticParser {
try {
//2.get dataSetId from queryCtx and chatCtx.
Long dataSetId = requestService.getDataSetId(queryCtx);
log.info("dataSetId:{}", dataSetId);
if (dataSetId == null) {
return;
}

View File

@@ -6,11 +6,11 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.stereotype.Component;
import java.util.List;
import java.util.ArrayList;
import java.util.Map;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@Component
@Slf4j
@@ -85,7 +85,7 @@ public class SqlPromptGenerator {
}
public List<List<Map<String, String>>> getExampleCombos(List<Map<String, String>> exampleList, int numFewShots,
int numSelfConsistency) {
int numSelfConsistency) {
List<List<Map<String, String>>> results = new ArrayList<>();
for (int i = 0; i < numSelfConsistency; i++) {
List<Map<String, String>> shuffledList = new ArrayList<>(exampleList);
@@ -118,7 +118,7 @@ public class SqlPromptGenerator {
}
public List<String> generateSqlPromptPool(LLMReq llmReq, List<String> schemaLinkStrPool,
List<List<Map<String, String>>> fewshotExampleListPool) {
List<List<Map<String, String>>> fewshotExampleListPool) {
List<String> sqlPromptPool = new ArrayList<>();
for (int i = 0; i < schemaLinkStrPool.size(); i++) {
String schemaLinkStr = schemaLinkStrPool.get(i);
@@ -129,4 +129,4 @@ public class SqlPromptGenerator {
return sqlPromptPool;
}
}
}

View File

@@ -15,6 +15,7 @@ import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;