mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 13:07:32 +00:00
# This is a combination of 2 commits.
(feature)(headless)Support custom prompt template. #1348.
This commit is contained in:
@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat;
|
||||
|
||||
import com.fasterxml.jackson.annotation.JsonIgnore;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
@@ -50,7 +51,8 @@ public class QueryContext {
|
||||
private WorkflowState workflowState;
|
||||
private QueryDataType queryDataType = QueryDataType.ALL;
|
||||
private LLMConfig llmConfig;
|
||||
private List<SqlExemplar> exemplars;
|
||||
private PromptConfig promptConfig;
|
||||
private List<SqlExemplar> dynamicExemplars;
|
||||
|
||||
public List<SemanticQuery> getCandidateQueries() {
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
|
||||
@@ -102,8 +102,8 @@ public class LLMRequestService {
|
||||
llmReq.setCurrentDate(DateUtils.getBeforeDate(0));
|
||||
llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
|
||||
llmReq.setLlmConfig(queryCtx.getLlmConfig());
|
||||
|
||||
llmReq.setExemplars(queryCtx.getExemplars());
|
||||
llmReq.setPromptConfig(queryCtx.getPromptConfig());
|
||||
llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars());
|
||||
|
||||
return llmReq;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.tencent.supersonic.headless.chat.parser.llm;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
|
||||
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
|
||||
@@ -10,10 +11,10 @@ import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
import dev.langchain4j.model.output.Response;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@@ -34,10 +35,8 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
+ "2.ALWAYS use `datediff()` as the date function."
|
||||
+ "3.DO NOT specify date filter in the where clause if not explicitly mentioned in the query."
|
||||
+ "4.ONLY respond with the converted SQL statement.\n"
|
||||
+ "#Exemplars:\n%s"
|
||||
+ "#UserQuery: %s "
|
||||
+ "#Schema: %s "
|
||||
+ "#SQL: ";
|
||||
+ "#Exemplars:\n{{exemplar}}"
|
||||
+ "#Question:{{question}} #Schema:{{schema}} #SQL:";
|
||||
|
||||
@Override
|
||||
public LLMResp generate(LLMReq llmReq) {
|
||||
@@ -80,16 +79,25 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
|
||||
private Prompt generatePrompt(LLMReq llmReq, List<SqlExemplar> fewshotExampleList) {
|
||||
StringBuilder exemplarsStr = new StringBuilder();
|
||||
for (SqlExemplar exemplar : fewshotExampleList) {
|
||||
String exemplarStr = String.format("#UserQuery: %s #Schema: %s #SQL: %s\n",
|
||||
String exemplarStr = String.format("#Question:%s #Schema:%s #SQL:%s\n",
|
||||
exemplar.getQuestion(), exemplar.getDbSchema(), exemplar.getSql());
|
||||
exemplarsStr.append(exemplarStr);
|
||||
}
|
||||
|
||||
String dataSemanticsStr = promptHelper.buildSchemaStr(llmReq);
|
||||
String questionAugmented = promptHelper.buildAugmentedQuestion(llmReq);
|
||||
String promptStr = String.format(INSTRUCTION, exemplarsStr, questionAugmented, dataSemanticsStr);
|
||||
|
||||
return PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||
Map<String, Object> variable = new HashMap<>();
|
||||
variable.put("exemplar", exemplarsStr);
|
||||
variable.put("question", questionAugmented);
|
||||
variable.put("schema", dataSemanticsStr);
|
||||
|
||||
// use custom prompt template if provided.
|
||||
PromptConfig promptConfig = llmReq.getPromptConfig();
|
||||
String prompTemplate = INSTRUCTION;
|
||||
if (promptConfig != null && StringUtils.isNotBlank(promptConfig.getPromptTemplate())) {
|
||||
prompTemplate = promptConfig.getPromptTemplate();
|
||||
}
|
||||
return PromptTemplate.from(prompTemplate).apply(variable);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -35,11 +35,11 @@ public class PromptHelper {
|
||||
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
|
||||
|
||||
List<SqlExemplar> exemplars = Lists.newArrayList();
|
||||
llmReq.getExemplars().stream().forEach(e -> {
|
||||
llmReq.getDynamicExemplars().stream().forEach(e -> {
|
||||
exemplars.add(e);
|
||||
});
|
||||
|
||||
int recallSize = exemplarRecallNumber - llmReq.getExemplars().size();
|
||||
int recallSize = exemplarRecallNumber - llmReq.getDynamicExemplars().size();
|
||||
if (recallSize > 0) {
|
||||
exemplars.addAll(exemplarService.recallExemplars(llmReq.getQueryText(), recallSize));
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
|
||||
import com.fasterxml.jackson.annotation.JsonValue;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.common.config.LLMConfig;
|
||||
import com.tencent.supersonic.common.config.PromptConfig;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.common.pojo.SqlExemplar;
|
||||
import lombok.Data;
|
||||
@@ -28,7 +29,10 @@ public class LLMReq {
|
||||
|
||||
private LLMConfig llmConfig;
|
||||
|
||||
private List<SqlExemplar> exemplars;
|
||||
private PromptConfig promptConfig;
|
||||
|
||||
private List<SqlExemplar> dynamicExemplars;
|
||||
|
||||
|
||||
@Data
|
||||
public static class ElementValue {
|
||||
|
||||
Reference in New Issue
Block a user