(improvement)(headless) Introduce side_information to the prompt and exemplar.

This commit is contained in:
jerryjzhang
2024-07-18 11:29:07 +08:00
parent f30c74c18f
commit 2eac301076
16 changed files with 128 additions and 165 deletions

View File

@@ -120,10 +120,10 @@ public class LLMRequestService {
/**
 * Runs text-to-SQL generation for the given request.
 *
 * <p>Obtains a {@code SqlGenStrategy} from {@code SqlGenStrategyFactory} using the request's
 * {@code sqlGenType}, delegates generation to it, and then copies the request's query text and
 * the schema's data set name onto the response before returning it.
 *
 * @param llmReq request supplying the SQL generation type, the schema and the query text
 * @return the response produced by the selected strategy, with the query text and the data set
 *         name filled in
 */
public LLMResp runText2SQL(LLMReq llmReq) {
SqlGenStrategy sqlGenStrategy = SqlGenStrategyFactory.get(llmReq.getSqlGenType());
// NOTE(review): modelName and dataSet are both read from getDataSetName(). This looks like the
// pre-rename and post-rename lines of a diff shown side by side -- confirm which pair is live.
String modelName = llmReq.getSchema().getDataSetName();
String dataSet = llmReq.getSchema().getDataSetName();
LLMResp result = sqlGenStrategy.generate(llmReq);
// The response always echoes the raw query text of the request.
result.setQuery(llmReq.getQueryText());
result.setModelName(modelName);
result.setDataSet(dataSet);
return result;
}

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.common.pojo.Constants;
import com.tencent.supersonic.common.jsqlparser.SqlValidHelper;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.query.QueryManager;
import com.tencent.supersonic.headless.chat.query.llm.LLMSemanticQuery;
@@ -35,6 +36,13 @@ public class LLMResponseService {
Map<String, Object> properties = new HashMap<>();
properties.put(Constants.CONTEXT, parseResult);
properties.put("type", "internal");
SqlExemplar exemplar = SqlExemplar.builder()
.question(queryCtx.getQueryText())
.sideInfo(parseResult.getLlmResp().getSideInfo())
.dbSchema(parseResult.getLlmResp().getSchema())
.sql(parseResult.getLlmResp().getSqlOutput())
.build();
properties.put(SqlExemplar.PROPERTY_KEY, exemplar);
parseInfo.setProperties(properties);
parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight));
parseInfo.setQueryMode(semanticQuery.getQueryMode());

View File

@@ -34,13 +34,15 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
+ "1.ALWAYS use `数据日期` as the date field."
+ "2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator."
+ "3.ALWAYS calculate the absolute date range by yourself."
+ "4.DO NOT include date filter in the where clause if not explicitly expressed in the question."
+ "4.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`."
+ "5.ONLY respond with the converted SQL statement.\n"
+ "#Exemplars:\n{{exemplar}}"
+ "#Question:{{question}} #Schema:{{schema}} #SQL:";
+ "#Question:{{question}} #Schema:{{schema}} #SideInfo:{{information}} #SQL:";
/**
 * Generates SQL for the request by running one LLM inference per few-shot exemplar list and
 * picking the final SQL with {@code ResponseHelper.selfConsistencyVote}.
 *
 * <p>Steps, as numbered in the body: (1) recall exemplar lists through
 * {@code promptHelper.getFewShotExemplars}; (2) build one prompt per exemplar list; (3) send
 * every prompt to the chat model from a parallel stream and record the outputs; (4) vote on the
 * outputs and fill {@code sqlOutput} and {@code sqlRespMap} on the response.
 *
 * <p>NOTE(review): this listing interleaves removed and added diff lines (two declarations of
 * {@code prompt}, two declarations of {@code llmResp}, two argument lines for the vote call) and
 * contains a hunk header; it is not compilable as shown. Confirm against the real file.
 *
 * @param llmReq request supplying the query text, model config and exemplar recall inputs
 * @return response carrying the voted SQL and the per-SQL response map
 */
@Override
public LLMResp generate(LLMReq llmReq) {
LLMResp llmResp = new LLMResp();
llmResp.setQuery(llmReq.getQueryText());
//1.recall exemplars
keyPipelineLog.info("OnePassSCSqlGenStrategy llmReq:\n{}", llmReq);
List<List<SqlExemplar>> exemplarsList = promptHelper.getFewShotExemplars(llmReq);
@@ -48,49 +50,51 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
//2.generate sql generation prompt for each self-consistency inference
Map<Prompt, List<SqlExemplar>> prompt2Exemplar = new HashMap<>();
for (List<SqlExemplar> exemplars : exemplarsList) {
Prompt prompt = generatePrompt(llmReq, exemplars);
// The exemplar list is handed to generatePrompt through the request object, so the shared
// llmReq is mutated on every iteration and ends up holding the last list.
llmReq.setDynamicExemplars(exemplars);
Prompt prompt = generatePrompt(llmReq, llmResp);
prompt2Exemplar.put(prompt, exemplars);
}
//3.perform multiple self-consistency inferences parallelly
Map<Prompt, String> prompt2Output = new ConcurrentHashMap<>();
// NOTE(review): keyed by the raw model output, so prompts that yield identical text collapse
// into a single entry and only one of those prompts is remembered.
Map<String, Prompt> output2Prompt = new ConcurrentHashMap<>();
prompt2Exemplar.keySet().parallelStream().forEach(prompt -> {
keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toUserMessage());
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getModelConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
String result = response.content().text();
prompt2Output.put(prompt, result);
output2Prompt.put(result, prompt);
keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", result);
}
);
//4.format response.
// NOTE(review): the added argument line passes output2Prompt.keySet(), i.e. each distinct output
// exactly once. If selfConsistencyVote weighs candidates by how often they occur, duplicates are
// lost here while prompt2Output.values() would keep them -- TODO confirm.
Pair<String, Map<String, Double>> sqlMapPair = ResponseHelper.selfConsistencyVote(
Lists.newArrayList(prompt2Output.values()));
LLMResp llmResp = new LLMResp();
llmResp.setQuery(promptHelper.buildAugmentedQuestion(llmReq));
llmResp.setDbSchema(promptHelper.buildSchemaStr(llmReq));
Lists.newArrayList(output2Prompt.keySet()));
llmResp.setSqlOutput(sqlMapPair.getLeft());
//TODO: should use the same few-shot exemplars as the one chose by self-consistency vote
llmResp.setSqlRespMap(ResponseHelper.buildSqlRespMap(exemplarsList.get(0), sqlMapPair.getRight()));
// NOTE(review): this lookup only finds the exemplars when the voted SQL is byte-for-byte one of
// the raw outputs. If selfConsistencyVote returns a normalized or rewritten SQL, usedExemplars
// would be null -- verify against ResponseHelper.
List<SqlExemplar> usedExemplars = prompt2Exemplar.get(output2Prompt.get(sqlMapPair.getLeft()));
llmResp.setSqlRespMap(ResponseHelper.buildSqlRespMap(usedExemplars, sqlMapPair.getRight()));
return llmResp;
}
private Prompt generatePrompt(LLMReq llmReq, List<SqlExemplar> fewshotExampleList) {
StringBuilder exemplarsStr = new StringBuilder();
for (SqlExemplar exemplar : fewshotExampleList) {
String exemplarStr = String.format("#Question:%s #Schema:%s #SQL:%s\n",
exemplar.getQuestion(), exemplar.getDbSchema(), exemplar.getSql());
exemplarsStr.append(exemplarStr);
private Prompt generatePrompt(LLMReq llmReq, LLMResp llmResp) {
StringBuilder exemplars = new StringBuilder();
for (SqlExemplar exemplar : llmReq.getDynamicExemplars()) {
String exemplarStr = String.format("#Question:%s #Schema:%s #SideInfo:%s #SQL:%s\n",
exemplar.getQuestion(), exemplar.getDbSchema(),
exemplar.getSideInfo(), exemplar.getSql());
exemplars.append(exemplarStr);
}
String dataSemanticsStr = promptHelper.buildSchemaStr(llmReq);
String questionAugmented = promptHelper.buildAugmentedQuestion(llmReq);
String dataSemantics = promptHelper.buildSchemaStr(llmReq);
String sideInformation = promptHelper.buildSideInformation(llmReq);
llmResp.setSchema(dataSemantics);
llmResp.setSideInfo(sideInformation);
Map<String, Object> variable = new HashMap<>();
variable.put("exemplar", exemplarsStr);
variable.put("question", questionAugmented);
variable.put("schema", dataSemanticsStr);
variable.put("exemplar", exemplars);
variable.put("question", llmReq.getQueryText());
variable.put("schema", dataSemantics);
variable.put("information", sideInformation);
// use custom prompt template if provided.
PromptConfig promptConfig = llmReq.getPromptConfig();

View File

@@ -73,6 +73,23 @@ public class PromptHelper {
linkingListStr, currentDataStr, termStr, priorExts);
}
/**
 * Builds the side-information text that accompanies a question in the SQL generation prompt.
 *
 * <p>The result has four segments separated by {@code ';'}, in this order: the linked
 * value/field facts, the current date sentence, the term description returned by
 * {@code buildTermStr}, and the request's prior extensions.
 *
 * <p>Each linked value is rendered as {@code ‘value’是一个‘field’} and the facts are joined with a
 * full-width comma so that several linked values stay distinguishable; a request without any
 * linking yields an empty first segment. A null current date or prior extension is rendered as
 * the literal text {@code null} by string concatenation/formatting.
 *
 * @param llmReq request supplying the linking results, the current date and the prior extensions
 * @return the four segments formatted as {@code "%s;%s;%s;%s"}
 */
public String buildSideInformation(LLMReq llmReq) {
    List<LLMReq.ElementValue> linkedValues = llmReq.getLinking();
    String currentDate = llmReq.getCurrentDate();
    String priorExts = llmReq.getPriorExts();

    List<String> priorLinkingList = new ArrayList<>();
    // A request that carries no linking result must not fail; it simply contributes no facts.
    if (linkedValues != null) {
        for (LLMReq.ElementValue value : linkedValues) {
            String fieldName = value.getFieldName();
            String fieldValue = value.getFieldValue();
            // Quote both the value and the field name with balanced marks.
            priorLinkingList.add("‘" + fieldValue + "’是一个‘" + fieldName + "’");
        }
    }

    String currentDataStr = "当前的日期是" + currentDate;
    // Separate the facts explicitly; joining with an empty string would run them together.
    String linkingListStr = String.join("，", priorLinkingList);
    String termStr = buildTermStr(llmReq);
    return String.format("%s;%s;%s;%s", linkingListStr, currentDataStr, termStr, priorExts);
}
public String buildSchemaStr(LLMReq llmReq) {
String tableStr = llmReq.getSchema().getDataSetName();
StringBuilder metricStr = new StringBuilder();

View File

@@ -10,9 +10,11 @@ public class LLMResp {
private String query;
private String modelName;
private String sideInfo;
private String dbSchema;
private String dataSet;
private String schema;
private String sqlOutput;