(improvement)(common)Rename SqlExemplar to Text2SQLExemplar.

This commit is contained in:
jerryjzhang
2024-07-18 14:19:56 +08:00
parent 2eac301076
commit 2425067091
16 changed files with 57 additions and 56 deletions

View File

@@ -4,7 +4,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
@@ -52,7 +52,7 @@ public class ChatQueryContext {
private QueryDataType queryDataType = QueryDataType.ALL;
private ChatModelConfig modelConfig;
private PromptConfig promptConfig;
private List<SqlExemplar> dynamicExemplars;
private List<Text2SQLExemplar> dynamicExemplars;
private SemanticParseInfo contextParseInfo;
public List<SemanticQuery> getCandidateQueries() {

View File

@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
@@ -36,13 +36,13 @@ public class LLMResponseService {
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, parseResult);
properties.put("type", "internal");
SqlExemplar exemplar = SqlExemplar.builder()
Text2SQLExemplar exemplar = Text2SQLExemplar.builder()
.question(queryCtx.getQueryText())
.sideInfo(parseResult.getLlmResp().getSideInfo())
.dbSchema(parseResult.getLlmResp().getSchema())
.sql(parseResult.getLlmResp().getSqlOutput())
.build();
properties.put(SqlExemplar.PROPERTY_KEY, exemplar);
properties.put(Text2SQLExemplar.PROPERTY_KEY, exemplar);
parseInfo.setProperties(properties);
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
parseInfo.setQueryMode(semanticQuery.getQueryMode());

View File

@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.parser.llm;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import dev.langchain4j.data.message.AiMessage;
@@ -45,11 +45,11 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
llmResp.setQuery(llmReq.getQueryText());
//1.recall exemplars
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:\n{}", llmReq);
List<List<SqlExemplar>> exemplarsList = promptHelper.getFewShotExemplars(llmReq);
List<List<Text2SQLExemplar>> exemplarsList = promptHelper.getFewShotExemplars(llmReq);
//2.generate sql generation prompt for each self-consistency inference
Map<Prompt, List<SqlExemplar>> prompt2Exemplar = new HashMap<>();
for (List<SqlExemplar> exemplars : exemplarsList) {
Map<Prompt, List<Text2SQLExemplar>> prompt2Exemplar = new HashMap<>();
for (List<Text2SQLExemplar> exemplars : exemplarsList) {
llmReq.setDynamicExemplars(exemplars);
Prompt prompt = generatePrompt(llmReq, llmResp);
prompt2Exemplar.put(prompt, exemplars);
@@ -61,9 +61,9 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toUserMessage());
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getModelConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
String result = response.content().text();
output2Prompt.put(result, prompt);
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", result);
String sqlOutput = StringUtils.normalizeSpace(response.content().text());
output2Prompt.put(sqlOutput, prompt);
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", sqlOutput);
}
);
@@ -71,7 +71,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
Pair<String, Map<String, Double>> sqlMapPair = ResponseHelper.selfConsistencyVote(
Lists.newArrayList(output2Prompt.keySet()));
llmResp.setSqlOutput(sqlMapPair.getLeft());
List<SqlExemplar> usedExemplars = prompt2Exemplar.get(output2Prompt.get(sqlMapPair.getLeft()));
List<Text2SQLExemplar> usedExemplars = prompt2Exemplar.get(output2Prompt.get(sqlMapPair.getLeft()));
llmResp.setSqlRespMap(ResponseHelper.buildSqlRespMap(usedExemplars, sqlMapPair.getRight()));
return llmResp;
@@ -79,7 +79,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
private Prompt generatePrompt(LLMReq llmReq, LLMResp llmResp) {
StringBuilder exemplars = new StringBuilder();
for (SqlExemplar exemplar : llmReq.getDynamicExemplars()) {
for (Text2SQLExemplar exemplar : llmReq.getDynamicExemplars()) {
String exemplarStr = String.format("#Question:%s #Schema:%s #SideInfo:%s #SQL:%s\n",
exemplar.getQuestion(), exemplar.getDbSchema(),
exemplar.getSideInfo(), exemplar.getSql());

View File

@@ -1,7 +1,7 @@
package com.tencent.supersonic.headless.chat.parser.llm;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.service.ExemplarService;
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
@@ -29,12 +29,12 @@ public class PromptHelper {
@Autowired
private ExemplarService exemplarService;
public List<List<SqlExemplar>> getFewShotExemplars(LLMReq llmReq) {
public List<List<Text2SQLExemplar>> getFewShotExemplars(LLMReq llmReq) {
int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
List<SqlExemplar> exemplars = Lists.newArrayList();
List<Text2SQLExemplar> exemplars = Lists.newArrayList();
llmReq.getDynamicExemplars().stream().forEach(e -> {
exemplars.add(e);
});
@@ -44,10 +44,10 @@ public class PromptHelper {
exemplars.addAll(exemplarService.recallExemplars(llmReq.getQueryText(), recallSize));
}
List<List<SqlExemplar>> results = new ArrayList<>();
List<List<Text2SQLExemplar>> results = new ArrayList<>();
// use random collection of exemplars for each self-consistency inference
for (int i = 0; i < selfConsistencyNumber; i++) {
List<SqlExemplar> shuffledList = new ArrayList<>(exemplars);
List<Text2SQLExemplar> shuffledList = new ArrayList<>(exemplars);
Collections.shuffle(shuffledList);
results.add(shuffledList.subList(0, fewShotNumber));
}

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.tuple.Pair;
@@ -54,7 +54,7 @@ public class ResponseHelper {
return Pair.of(inputMax, votePercentage);
}
public static Map<String, LLMSqlResp> buildSqlRespMap(List<SqlExemplar> sqlExamples,
public static Map<String, LLMSqlResp> buildSqlRespMap(List<Text2SQLExemplar> sqlExamples,
Map<String, Double> sqlMap) {
if (sqlMap == null) {
return new HashMap<>();

View File

@@ -4,7 +4,7 @@ import com.fasterxml.jackson.annotation.JsonValue;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import lombok.Data;
@@ -30,7 +30,7 @@ public class LLMReq {
private ChatModelConfig modelConfig;
private PromptConfig promptConfig;
private List<SqlExemplar> dynamicExemplars;
private List<Text2SQLExemplar> dynamicExemplars;
@Data

View File

@@ -1,6 +1,6 @@
package com.tencent.supersonic.headless.chat.query.llm.s2sql;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
@@ -16,6 +16,6 @@ public class LLMSqlResp {
private double sqlWeight;
private List<SqlExemplar> fewShots;
private List<Text2SQLExemplar> fewShots;
}