# This is a combination of 2 commits.

(feature)(headless)Support custom prompt template. #1348.
This commit is contained in:
jerryjzhang
2024-07-05 15:13:31 +08:00
parent 097f2f4fe7
commit 72465cd88c
15 changed files with 180 additions and 138 deletions

View File

@@ -5,6 +5,7 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.config.VisualConfig;
import com.tencent.supersonic.common.pojo.RecordInfo;
import lombok.Data;
@@ -33,6 +34,7 @@ public class Agent extends RecordInfo {
private List<String> examples;
private String agentConfig;
private LLMConfig llmConfig;
private PromptConfig promptConfig;
private MultiTurnConfig multiTurnConfig;
private VisualConfig visualConfig;

View File

@@ -66,10 +66,10 @@ public class NL2SQLParser implements ChatParser {
if (!chatParseContext.enableNL2SQL() || checkSkip(parseResp)) {
return;
}
processMultiTurn(chatParseContext);
processMultiTurn(chatParseContext);
QueryReq queryReq = QueryReqConverter.buildText2SqlQueryReq(chatParseContext);
addExemplars(chatParseContext.getAgent().getId(), queryReq);
addDynamicExemplars(chatParseContext.getAgent().getId(), queryReq);
ChatQueryService chatQueryService = ContextUtils.getBean(ChatQueryService.class);
ParseResp text2SqlParseResp = chatQueryService.performParsing(queryReq);
@@ -224,13 +224,13 @@ public class NL2SQLParser implements ChatParser {
return contextualList;
}
private void addExemplars(Integer agentId, QueryReq queryReq) {
private void addDynamicExemplars(Integer agentId, QueryReq queryReq) {
ExemplarServiceImpl exemplarManager = ContextUtils.getBean(ExemplarServiceImpl.class);
EmbeddingConfig embeddingConfig = ContextUtils.getBean(EmbeddingConfig.class);
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
List<SqlExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
queryReq.getQueryText(), 5);
queryReq.getExemplars().addAll(exemplars);
queryReq.getDynamicExemplars().addAll(exemplars);
}
@Builder

View File

@@ -3,7 +3,6 @@ package com.tencent.supersonic.chat.server.persistence.dataobject;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import com.tencent.supersonic.common.config.VisualConfig;
import lombok.Data;
import java.util.Date;
@@ -63,4 +62,6 @@ public class AgentDO {
private String visualConfig;
private String promptConfig;
}

View File

@@ -13,6 +13,7 @@ import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.config.VisualConfig;
import com.tencent.supersonic.common.util.JsonUtil;
import lombok.extern.slf4j.Slf4j;
@@ -121,6 +122,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
agent.setAgentConfig(agentDO.getConfig());
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
agent.setLlmConfig(JsonUtil.toObject(agentDO.getLlmConfig(), LLMConfig.class));
agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class));
agent.setMultiTurnConfig(JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));
return agent;
@@ -134,6 +136,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO>
agentDO.setLlmConfig(JsonUtil.toString(agent.getLlmConfig()));
agentDO.setMultiTurnConfig(JsonUtil.toString(agent.getMultiTurnConfig()));
agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig()));
agentDO.setPromptConfig(JsonUtil.toString(agent.getPromptConfig()));
if (agentDO.getStatus() == null) {
agentDO.setStatus(1);
}

View File

@@ -31,6 +31,7 @@ public class QueryReqConverter {
queryReq.setMapInfo(queryReq.getMapInfo());
}
queryReq.setLlmConfig(agent.getLlmConfig());
queryReq.setPromptConfig(agent.getPromptConfig());
return queryReq;
}

View File

@@ -0,0 +1,14 @@
package com.tencent.supersonic.common.config;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class PromptConfig {
private String promptTemplate;
}

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.api.pojo.request;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.headless.api.pojo.QueryDataType;
@@ -27,5 +28,6 @@ public class QueryReq {
private SchemaMapInfo mapInfo = new SchemaMapInfo();
private QueryDataType queryDataType = QueryDataType.ALL;
private LLMConfig llmConfig;
private List<SqlExemplar> exemplars = Lists.newArrayList();
private PromptConfig promptConfig;
private List<SqlExemplar> dynamicExemplars = Lists.newArrayList();
}

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat;
import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.common.config.LLMConfig;
@@ -50,7 +51,8 @@ public class QueryContext {
private WorkflowState workflowState;
private QueryDataType queryDataType = QueryDataType.ALL;
private LLMConfig llmConfig;
private List<SqlExemplar> exemplars;
private PromptConfig promptConfig;
private List<SqlExemplar> dynamicExemplars;
public List<SemanticQuery> getCandidateQueries() {
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);

View File

@@ -102,8 +102,8 @@ public class LLMRequestService {
llmReq.setCurrentDate(DateUtils.getBeforeDate(0));
llmReq.setSqlGenType(LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
llmReq.setLlmConfig(queryCtx.getLlmConfig());
llmReq.setExemplars(queryCtx.getExemplars());
llmReq.setPromptConfig(queryCtx.getPromptConfig());
llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars());
return llmReq;
}

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.chat.parser.llm;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
@@ -10,10 +11,10 @@ import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.springframework.stereotype.Service;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -34,10 +35,8 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
+ "2.ALWAYS use `datediff()` as the date function."
+ "3.DO NOT specify date filter in the where clause if not explicitly mentioned in the query."
+ "4.ONLY respond with the converted SQL statement.\n"
+ "#Exemplars:\n%s"
+ "#UserQuery: %s "
+ "#Schema: %s "
+ "#SQL: ";
+ "#Exemplars:\n{{exemplar}}"
+ "#Question:{{question}} #Schema:{{schema}} #SQL:";
@Override
public LLMResp generate(LLMReq llmReq) {
@@ -80,16 +79,25 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
private Prompt generatePrompt(LLMReq llmReq, List<SqlExemplar> fewshotExampleList) {
StringBuilder exemplarsStr = new StringBuilder();
for (SqlExemplar exemplar : fewshotExampleList) {
String exemplarStr = String.format("#UserQuery: %s #Schema: %s #SQL: %s\n",
String exemplarStr = String.format("#Question:%s #Schema:%s #SQL:%s\n",
exemplar.getQuestion(), exemplar.getDbSchema(), exemplar.getSql());
exemplarsStr.append(exemplarStr);
}
String dataSemanticsStr = promptHelper.buildSchemaStr(llmReq);
String questionAugmented = promptHelper.buildAugmentedQuestion(llmReq);
String promptStr = String.format(INSTRUCTION, exemplarsStr, questionAugmented, dataSemanticsStr);
return PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
Map<String, Object> variable = new HashMap<>();
variable.put("exemplar", exemplarsStr);
variable.put("question", questionAugmented);
variable.put("schema", dataSemanticsStr);
// use custom prompt template if provided.
PromptConfig promptConfig = llmReq.getPromptConfig();
String prompTemplate = INSTRUCTION;
if (promptConfig != null && StringUtils.isNotBlank(promptConfig.getPromptTemplate())) {
prompTemplate = promptConfig.getPromptTemplate();
}
return PromptTemplate.from(prompTemplate).apply(variable);
}
@Override

View File

@@ -35,11 +35,11 @@ public class PromptHelper {
int selfConsistencyNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER));
List<SqlExemplar> exemplars = Lists.newArrayList();
llmReq.getExemplars().stream().forEach(e -> {
llmReq.getDynamicExemplars().stream().forEach(e -> {
exemplars.add(e);
});
int recallSize = exemplarRecallNumber - llmReq.getExemplars().size();
int recallSize = exemplarRecallNumber - llmReq.getDynamicExemplars().size();
if (recallSize > 0) {
exemplars.addAll(exemplarService.recallExemplars(llmReq.getQueryText(), recallSize));
}

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
import com.fasterxml.jackson.annotation.JsonValue;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.config.LLMConfig;
import com.tencent.supersonic.common.config.PromptConfig;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.common.pojo.SqlExemplar;
import lombok.Data;
@@ -28,7 +29,10 @@ public class LLMReq {
private LLMConfig llmConfig;
private List<SqlExemplar> exemplars;
private PromptConfig promptConfig;
private List<SqlExemplar> dynamicExemplars;
@Data
public static class ElementValue {

View File

@@ -310,7 +310,7 @@ CREATE TABLE IF NOT EXISTS `s2_term` (
);
--20240520
alter table s2_agent add column `llm_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL;
alter table s2_agent add column `llm_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL COMMENT '大模型配置';
alter table s2_agent add column `multi_turn_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL;
alter table s2_model add column `ext` varchar(1000) DEFAULT NULL;
@@ -348,3 +348,6 @@ CREATE TABLE IF NOT EXISTS `s2_chat_memory` (
`updated_by` varchar(100) NOT NULL ,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8;
--20240705
alter table s2_agent add column `prompt_config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL COMMENT '提示词配置';

View File

@@ -374,6 +374,7 @@ CREATE TABLE IF NOT EXISTS s2_agent
examples varchar(500) null,
config varchar(2000) null,
llm_config varchar(2000) null,
prompt_config varchar(5000) null,
multi_turn_config varchar(2000) null,
visual_config varchar(2000) null,
created_by varchar(100) null,

View File

@@ -374,6 +374,7 @@ CREATE TABLE IF NOT EXISTS s2_agent
examples varchar(500) null,
config varchar(2000) null,
llm_config varchar(2000) null,
prompt_config varchar(5000) null,
multi_turn_config varchar(2000) null,
visual_config varchar(2000) null,
created_by varchar(100) null,