mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 21:17:08 +00:00
(improvement)(common)Rename SqlExemplar to Text2SQLExemplar.
This commit is contained in:
@@ -4,7 +4,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
|
||||
@@ -52,7 +52,7 @@ public class ChatQueryContext {
|
||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||
private ChatModelConfig modelConfig;
|
||||
private PromptConfig promptConfig;
|
||||
private List<SqlExemplar> dynamicExemplars;
|
||||
private List<Text2SQLExemplar> dynamicExemplars;
|
||||
private SemanticParseInfo contextParseInfo;
|
||||
|
||||
public List<SemanticQuery> getCandidateQueries() {
|
||||
|
||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
|
||||
@@ -36,13 +36,13 @@ public class LLMResponseService {
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
properties.put(Constants.CONTEXT, parseResult);
|
||||
properties.put("type", "internal");
|
||||
SqlExemplar exemplar = SqlExemplar.builder()
|
||||
Text2SQLExemplar exemplar = Text2SQLExemplar.builder()
|
||||
.question(queryCtx.getQueryText())
|
||||
.sideInfo(parseResult.getLlmResp().getSideInfo())
|
||||
.dbSchema(parseResult.getLlmResp().getSchema())
|
||||
.sql(parseResult.getLlmResp().getSqlOutput())
|
||||
.build();
|
||||
properties.put(SqlExemplar.PROPERTY_KEY, exemplar);
|
||||
properties.put(Text2SQLExemplar.PROPERTY_KEY, exemplar);
|
||||
parseInfo.setProperties(properties);
|
||||
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
|
||||
@@ -2,7 +2,7 @@ package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
@@ -45,11 +45,11 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
llmResp.setQuery(llmReq.getQueryText());
|
||||
//1.recall exemplars
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:\n{}", llmReq);
|
||||
List<List<SqlExemplar>> exemplarsList = promptHelper.getFewShotExemplars(llmReq);
|
||||
List<List<Text2SQLExemplar>> exemplarsList = promptHelper.getFewShotExemplars(llmReq);
|
||||
|
||||
//2.generate sql generation prompt for each self-consistency inference
|
||||
Map<Prompt, List<SqlExemplar>> prompt2Exemplar = new HashMap<>();
|
||||
for (List<SqlExemplar> exemplars : exemplarsList) {
|
||||
Map<Prompt, List<Text2SQLExemplar>> prompt2Exemplar = new HashMap<>();
|
||||
for (List<Text2SQLExemplar> exemplars : exemplarsList) {
|
||||
llmReq.setDynamicExemplars(exemplars);
|
||||
Prompt prompt = generatePrompt(llmReq, llmResp);
|
||||
prompt2Exemplar.put(prompt, exemplars);
|
||||
@@ -61,9 +61,9 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toUserMessage());
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getModelConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
String result = response.content().text();
|
||||
output2Prompt.put(result, prompt);
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", result);
|
||||
String sqlOutput = StringUtils.normalizeSpace(response.content().text());
|
||||
output2Prompt.put(sqlOutput, prompt);
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", sqlOutput);
|
||||
}
|
||||
);
|
||||
|
||||
@@ -71,7 +71,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
Pair<String, Map<String, Double>> sqlMapPair = ResponseHelper.selfConsistencyVote(
|
||||
Lists.newArrayList(output2Prompt.keySet()));
|
||||
llmResp.setSqlOutput(sqlMapPair.getLeft());
|
||||
List<SqlExemplar> usedExemplars = prompt2Exemplar.get(output2Prompt.get(sqlMapPair.getLeft()));
|
||||
List<Text2SQLExemplar> usedExemplars = prompt2Exemplar.get(output2Prompt.get(sqlMapPair.getLeft()));
|
||||
llmResp.setSqlRespMap(ResponseHelper.buildSqlRespMap(usedExemplars, sqlMapPair.getRight()));
|
||||
|
||||
return llmResp;
|
||||
@@ -79,7 +79,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
|
||||
private Prompt generatePrompt(LLMReq llmReq, LLMResp llmResp) {
|
||||
StringBuilder exemplars = new StringBuilder();
|
||||
for (SqlExemplar exemplar : llmReq.getDynamicExemplars()) {
|
||||
for (Text2SQLExemplar exemplar : llmReq.getDynamicExemplars()) {
|
||||
String exemplarStr = String.format("#Question:%s #Schema:%s #SideInfo:%s #SQL:%s\n",
|
||||
exemplar.getQuestion(), exemplar.getDbSchema(),
|
||||
exemplar.getSideInfo(), exemplar.getSql());
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.service.ExemplarService;
|
||||
import com.tencent.supersonic.headless.chat.parser.ParserConfig;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
@@ -29,12 +29,12 @@ public class PromptHelper {
|
||||
@Autowired
|
||||
private ExemplarService exemplarService;
|
||||
|
||||
public List<List<SqlExemplar>> getFewShotExemplars(LLMReq llmReq) {
|
||||
public List<List<Text2SQLExemplar>> getFewShotExemplars(LLMReq llmReq) {
|
||||
int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
|
||||
int fewShotNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_FEW_SHOT_NUMBER));
|
||||
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
|
||||
|
||||
List<SqlExemplar> exemplars = Lists.newArrayList();
|
||||
List<Text2SQLExemplar> exemplars = Lists.newArrayList();
|
||||
llmReq.getDynamicExemplars().stream().forEach(e -> {
|
||||
exemplars.add(e);
|
||||
});
|
||||
@@ -44,10 +44,10 @@ public class PromptHelper {
|
||||
exemplars.addAll(exemplarService.recallExemplars(llmReq.getQueryText(), recallSize));
|
||||
}
|
||||
|
||||
List<List<SqlExemplar>> results = new ArrayList<>();
|
||||
List<List<Text2SQLExemplar>> results = new ArrayList<>();
|
||||
// use random collection of exemplars for each self-consistency inference
|
||||
for (int i = 0; i < selfConsistencyNumber; i++) {
|
||||
List<SqlExemplar> shuffledList = new ArrayList<>(exemplars);
|
||||
List<Text2SQLExemplar> shuffledList = new ArrayList<>(exemplars);
|
||||
Collections.shuffle(shuffledList);
|
||||
results.add(shuffledList.subList(0, fewShotNumber));
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMSqlResp;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
@@ -54,7 +54,7 @@ public class ResponseHelper {
|
||||
return Pair.of(inputMax, votePercentage);
|
||||
}
|
||||
|
||||
public static Map<String, LLMSqlResp> buildSqlRespMap(List<SqlExemplar> sqlExamples,
|
||||
public static Map<String, LLMSqlResp> buildSqlRespMap(List<Text2SQLExemplar> sqlExamples,
|
||||
Map<String, Double> sqlMap) {
|
||||
if (sqlMap == null) {
|
||||
return new HashMap<>();
|
||||
|
||||
@@ -4,7 +4,7 @@ import com.fasterxml.jackson.annotation.JsonValue;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import lombok.Data;
|
||||
|
||||
@@ -30,7 +30,7 @@ public class LLMReq {
|
||||
private ChatModelConfig modelConfig;
|
||||
private PromptConfig promptConfig;
|
||||
|
||||
private List<SqlExemplar> dynamicExemplars;
|
||||
private List<Text2SQLExemplar> dynamicExemplars;
|
||||
|
||||
|
||||
@Data
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Builder;
|
||||
import lombok.Data;
|
||||
@@ -16,6 +16,6 @@ public class LLMSqlResp {
|
||||
|
||||
private double sqlWeight;
|
||||
|
||||
private List<SqlExemplar> fewShots;
|
||||
private List<Text2SQLExemplar> fewShots;
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user