mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 13:07:32 +00:00
(improvement)(headless)Introduce side_information to the prompt and exemplar.
This commit is contained in:
@@ -120,10 +120,10 @@ public class LLMRequestService {
|
||||
|
||||
public LLMResp runText2SQL(LLMReq llmReq) {
|
||||
SqlGenStrategy sqlGenStrategy = SqlGenStrategyFactory.get(llmReq.getSqlGenType());
|
||||
String modelName = llmReq.getSchema().getDataSetName();
|
||||
String dataSet = llmReq.getSchema().getDataSetName();
|
||||
LLMResp result = sqlGenStrategy.generate(llmReq);
|
||||
result.setQuery(llmReq.getQueryText());
|
||||
result.setModelName(modelName);
|
||||
result.setDataSet(dataSet);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.tencent.supersonic.common.pojo.Constants;
|
||||
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
|
||||
import com.tencent.supersonic.headless.chat.query.QueryManager;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
|
||||
@@ -35,6 +36,13 @@ public class LLMResponseService {
|
||||
Map<String, Object> properties = new HashMap<>();
|
||||
properties.put(Constants.CONTEXT, parseResult);
|
||||
properties.put("type", "internal");
|
||||
SqlExemplar exemplar = SqlExemplar.builder()
|
||||
.question(queryCtx.getQueryText())
|
||||
.sideInfo(parseResult.getLlmResp().getSideInfo())
|
||||
.dbSchema(parseResult.getLlmResp().getSchema())
|
||||
.sql(parseResult.getLlmResp().getSqlOutput())
|
||||
.build();
|
||||
properties.put(SqlExemplar.PROPERTY_KEY, exemplar);
|
||||
parseInfo.setProperties(properties);
|
||||
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
|
||||
parseInfo.setQueryMode(semanticQuery.getQueryMode());
|
||||
|
||||
@@ -34,13 +34,15 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
+ "1.ALWAYS use `数据日期` as the date field."
|
||||
+ "2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
|
||||
+ "3.ALWAYS calculate the absolute date range by yourself."
|
||||
+ "4.DO NOT include date filter in the where clause if not explicitly expressed in the question."
|
||||
+ "4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
|
||||
+ "5.ONLY respond with the converted SQL statement.\n"
|
||||
+ "#Exemplars:\n{{exemplar}}"
|
||||
+ "#Question:{{question}} #Schema:{{schema}} #SQL:";
|
||||
+ "#Question:{{question}} #Schema:{{schema}} #SideInfo:{{information}} #SQL:";
|
||||
|
||||
@Override
|
||||
public LLMResp generate(LLMReq llmReq) {
|
||||
LLMResp llmResp = new LLMResp();
|
||||
llmResp.setQuery(llmReq.getQueryText());
|
||||
//1.recall exemplars
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:\n{}", llmReq);
|
||||
List<List<SqlExemplar>> exemplarsList = promptHelper.getFewShotExemplars(llmReq);
|
||||
@@ -48,49 +50,51 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
//2.generate sql generation prompt for each self-consistency inference
|
||||
Map<Prompt, List<SqlExemplar>> prompt2Exemplar = new HashMap<>();
|
||||
for (List<SqlExemplar> exemplars : exemplarsList) {
|
||||
Prompt prompt = generatePrompt(llmReq, exemplars);
|
||||
llmReq.setDynamicExemplars(exemplars);
|
||||
Prompt prompt = generatePrompt(llmReq, llmResp);
|
||||
prompt2Exemplar.put(prompt, exemplars);
|
||||
}
|
||||
|
||||
//3.perform multiple self-consistency inferences parallelly
|
||||
Map<Prompt, String> prompt2Output = new ConcurrentHashMap<>();
|
||||
Map<String, Prompt> output2Prompt = new ConcurrentHashMap<>();
|
||||
prompt2Exemplar.keySet().parallelStream().forEach(prompt -> {
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toUserMessage());
|
||||
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getModelConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
String result = response.content().text();
|
||||
prompt2Output.put(prompt, result);
|
||||
output2Prompt.put(result, prompt);
|
||||
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", result);
|
||||
}
|
||||
);
|
||||
|
||||
//4.format response.
|
||||
Pair<String, Map<String, Double>> sqlMapPair = ResponseHelper.selfConsistencyVote(
|
||||
Lists.newArrayList(prompt2Output.values()));
|
||||
LLMResp llmResp = new LLMResp();
|
||||
llmResp.setQuery(promptHelper.buildAugmentedQuestion(llmReq));
|
||||
llmResp.setDbSchema(promptHelper.buildSchemaStr(llmReq));
|
||||
Lists.newArrayList(output2Prompt.keySet()));
|
||||
llmResp.setSqlOutput(sqlMapPair.getLeft());
|
||||
//TODO: should use the same few-shot exemplars as the one chose by self-consistency vote
|
||||
llmResp.setSqlRespMap(ResponseHelper.buildSqlRespMap(exemplarsList.get(0), sqlMapPair.getRight()));
|
||||
List<SqlExemplar> usedExemplars = prompt2Exemplar.get(output2Prompt.get(sqlMapPair.getLeft()));
|
||||
llmResp.setSqlRespMap(ResponseHelper.buildSqlRespMap(usedExemplars, sqlMapPair.getRight()));
|
||||
|
||||
return llmResp;
|
||||
}
|
||||
|
||||
private Prompt generatePrompt(LLMReq llmReq, List<SqlExemplar> fewshotExampleList) {
|
||||
StringBuilder exemplarsStr = new StringBuilder();
|
||||
for (SqlExemplar exemplar : fewshotExampleList) {
|
||||
String exemplarStr = String.format("#Question:%s #Schema:%s #SQL:%s\n",
|
||||
exemplar.getQuestion(), exemplar.getDbSchema(), exemplar.getSql());
|
||||
exemplarsStr.append(exemplarStr);
|
||||
private Prompt generatePrompt(LLMReq llmReq, LLMResp llmResp) {
|
||||
StringBuilder exemplars = new StringBuilder();
|
||||
for (SqlExemplar exemplar : llmReq.getDynamicExemplars()) {
|
||||
String exemplarStr = String.format("#Question:%s #Schema:%s #SideInfo:%s #SQL:%s\n",
|
||||
exemplar.getQuestion(), exemplar.getDbSchema(),
|
||||
exemplar.getSideInfo(), exemplar.getSql());
|
||||
exemplars.append(exemplarStr);
|
||||
}
|
||||
String dataSemanticsStr = promptHelper.buildSchemaStr(llmReq);
|
||||
String questionAugmented = promptHelper.buildAugmentedQuestion(llmReq);
|
||||
String dataSemantics = promptHelper.buildSchemaStr(llmReq);
|
||||
String sideInformation = promptHelper.buildSideInformation(llmReq);
|
||||
llmResp.setSchema(dataSemantics);
|
||||
llmResp.setSideInfo(sideInformation);
|
||||
|
||||
Map<String, Object> variable = new HashMap<>();
|
||||
variable.put("exemplar", exemplarsStr);
|
||||
variable.put("question", questionAugmented);
|
||||
variable.put("schema", dataSemanticsStr);
|
||||
variable.put("exemplar", exemplars);
|
||||
variable.put("question", llmReq.getQueryText());
|
||||
variable.put("schema", dataSemantics);
|
||||
variable.put("information", sideInformation);
|
||||
|
||||
// use custom prompt template if provided.
|
||||
PromptConfig promptConfig = llmReq.getPromptConfig();
|
||||
|
||||
@@ -73,6 +73,23 @@ public class PromptHelper {
|
||||
linkingListStr, currentDataStr, termStr, priorExts);
|
||||
}
|
||||
|
||||
public String buildSideInformation(LLMReq llmReq) {
|
||||
List<LLMReq.ElementValue> linkedValues = llmReq.getLinking();
|
||||
String currentDate = llmReq.getCurrentDate();
|
||||
String priorExts = llmReq.getPriorExts();
|
||||
|
||||
List<String> priorLinkingList = new ArrayList<>();
|
||||
for (LLMReq.ElementValue value : linkedValues) {
|
||||
String fieldName = value.getFieldName();
|
||||
String fieldValue = value.getFieldValue();
|
||||
priorLinkingList.add("‘" + fieldValue + "‘是一个‘" + fieldName + "‘");
|
||||
}
|
||||
String currentDataStr = "当前的日期是" + currentDate;
|
||||
String linkingListStr = String.join(",", priorLinkingList);
|
||||
String termStr = buildTermStr(llmReq);
|
||||
return String.format("%s;%s;%s;%s", linkingListStr, currentDataStr, termStr, priorExts);
|
||||
}
|
||||
|
||||
public String buildSchemaStr(LLMReq llmReq) {
|
||||
String tableStr = llmReq.getSchema().getDataSetName();
|
||||
StringBuilder metricStr = new StringBuilder();
|
||||
|
||||
@@ -10,9 +10,11 @@ public class LLMResp {
|
||||
|
||||
private String query;
|
||||
|
||||
private String modelName;
|
||||
private String sideInfo;
|
||||
|
||||
private String dbSchema;
|
||||
private String dataSet;
|
||||
|
||||
private String schema;
|
||||
|
||||
private String sqlOutput;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user