[feature][headless-chat]Introduce ChatApp to support more flexible chat model config.#1739

This commit is contained in:
jerryjzhang
2024-10-12 11:51:37 +08:00
parent 4408bf4325
commit 2717a1603c
25 changed files with 220 additions and 175 deletions

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.server.agent;
import com.alibaba.fastjson.JSONObject; import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.RecordInfo; import com.tencent.supersonic.common.pojo.RecordInfo;
import com.tencent.supersonic.common.pojo.enums.ChatModelType; import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import lombok.Data; import lombok.Data;
@@ -24,6 +25,7 @@ public class Agent extends RecordInfo {
private Integer enableMemoryReview; private Integer enableMemoryReview;
private String toolConfig; private String toolConfig;
private Map<ChatModelType, Integer> chatModelConfig = Collections.EMPTY_MAP; private Map<ChatModelType, Integer> chatModelConfig = Collections.EMPTY_MAP;
private Map<String, ChatApp> chatAppConfig = Collections.EMPTY_MAP;
private PromptConfig promptConfig; private PromptConfig promptConfig;
private MultiTurnConfig multiTurnConfig; private MultiTurnConfig multiTurnConfig;
private VisualConfig visualConfig; private VisualConfig visualConfig;

View File

@@ -8,8 +8,8 @@ import com.tencent.supersonic.chat.server.parser.ParserConfig;
import com.tencent.supersonic.chat.server.pojo.ExecuteContext; import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatManageService; import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.util.ModelConfigHelper; import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.enums.ChatModelType; import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.response.QueryState; import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import dev.langchain4j.data.message.AiMessage; import dev.langchain4j.data.message.AiMessage;
@@ -28,26 +28,35 @@ import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULT
public class PlainTextExecutor implements ChatQueryExecutor { public class PlainTextExecutor implements ChatQueryExecutor {
private static final String APP_KEY = "SMALL_TALK";
private static final String INSTRUCTION = "" + "#Role: You are a nice person to talk to.\n" private static final String INSTRUCTION = "" + "#Role: You are a nice person to talk to.\n"
+ "#Task: Respond quickly and nicely to the user." + "#Task: Respond quickly and nicely to the user."
+ "#Rules: 1.ALWAYS use the same language as the input.\n" + "#History Inputs: %s\n" + "#Rules: 1.ALWAYS use the same language as the input.\n" + "#History Inputs: %s\n"
+ "#Current Input: %s\n" + "#Your response: "; + "#Current Input: %s\n" + "#Your response: ";
public PlainTextExecutor() {
ChatAppManager.register(ChatApp.builder().key(APP_KEY).prompt(INSTRUCTION).name("闲聊对话")
.description("直接将原始输入透传大模型").enable(true).build());
}
@Override @Override
public QueryResult execute(ExecuteContext executeContext) { public QueryResult execute(ExecuteContext executeContext) {
if (!"PLAIN_TEXT".equals(executeContext.getParseInfo().getQueryMode())) { if (!"PLAIN_TEXT".equals(executeContext.getParseInfo().getQueryMode())) {
return null; return null;
} }
String promptStr = String.format(INSTRUCTION, getHistoryInputs(executeContext),
executeContext.getQueryText());
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
AgentService agentService = ContextUtils.getBean(AgentService.class); AgentService agentService = ContextUtils.getBean(AgentService.class);
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId()); Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
ChatApp chatApp = chatAgent.getChatAppConfig().get(APP_KEY);
if (!chatApp.isEnable()) {
return null;
}
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel( String promptStr = String.format(chatApp.getPrompt(), getHistoryInputs(executeContext),
ModelConfigHelper.getChatModelConfig(chatAgent, ChatModelType.RESPONSE_GENERATE)); executeContext.getQueryText());
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
ChatLanguageModel chatLanguageModel =
ModelProvider.getChatModel(chatApp.getChatModelConfig());
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage()); Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
QueryResult result = new QueryResult(); QueryResult result = new QueryResult();
@@ -60,25 +69,12 @@ public class PlainTextExecutor implements ChatQueryExecutor {
private String getHistoryInputs(ExecuteContext executeContext) { private String getHistoryInputs(ExecuteContext executeContext) {
StringBuilder historyInput = new StringBuilder(); StringBuilder historyInput = new StringBuilder();
List<QueryResp> queryResps = getHistoryQueries(executeContext.getChatId(), 5);
queryResps.stream().forEach(p -> {
historyInput.append(p.getQueryText());
historyInput.append(";");
AgentService agentService = ContextUtils.getBean(AgentService.class); });
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
MultiTurnConfig agentMultiTurnConfig = chatAgent.getMultiTurnConfig();
Boolean globalMultiTurnConfig =
Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
Boolean multiTurnConfig =
agentMultiTurnConfig != null ? agentMultiTurnConfig.isEnableMultiTurn()
: globalMultiTurnConfig;
if (Boolean.TRUE.equals(multiTurnConfig)) {
List<QueryResp> queryResps = getHistoryQueries(executeContext.getChatId(), 5);
queryResps.stream().forEach(p -> {
historyInput.append(p.getQueryText());
historyInput.append(";");
});
}
return historyInput.toString(); return historyInput.toString();
} }

View File

@@ -6,7 +6,8 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import com.tencent.supersonic.chat.server.service.AgentService; import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.chat.server.util.ModelConfigHelper; import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
import com.tencent.supersonic.common.pojo.enums.ChatModelType; import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.util.ChatAppManager;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt; import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate; import dev.langchain4j.model.input.PromptTemplate;
@@ -29,11 +30,11 @@ public class MemoryReviewTask {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
public static final String APP_KEY = "MEMORY_REVIEW";
private static final String INSTRUCTION = "" private static final String INSTRUCTION = ""
+ "\n#Role: You are a senior data engineer experienced in writing SQL." + "\n#Role: You are a senior data engineer experienced in writing SQL."
+ "\n#Task: Your will be provided with a user question and the SQL written by a junior engineer," + "\n#Task: Your will be provided with a user question and the SQL written by a junior engineer,"
+ "please take a review and give your opinion." + "please take a review and give your opinion." + "\n#Rules: "
+ "\n#Rules: "
+ "\n1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`." + "\n1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`."
+ "\n2.NO NEED to check date filters as the junior engineer seldom makes mistakes in this regard." + "\n2.NO NEED to check date filters as the junior engineer seldom makes mistakes in this regard."
+ "\n#Question: %s" + "\n#Schema: %s" + "\n#SideInfo: %s" + "\n#SQL: %s" + "\n#Question: %s" + "\n#Schema: %s" + "\n#SideInfo: %s" + "\n#SQL: %s"
@@ -47,6 +48,11 @@ public class MemoryReviewTask {
@Autowired @Autowired
private AgentService agentService; private AgentService agentService;
public MemoryReviewTask() {
ChatAppManager.register(ChatApp.builder().key(APP_KEY).prompt(INSTRUCTION).name("记忆启用评估")
.description("通过大模型对记忆做正确性评估以决定是否启用").enable(false).build());
}
@Scheduled(fixedDelay = 60 * 1000) @Scheduled(fixedDelay = 60 * 1000)
public void review() { public void review() {
try { try {
@@ -58,16 +64,22 @@ public class MemoryReviewTask {
private void processMemory(ChatMemoryDO m) { private void processMemory(ChatMemoryDO m) {
Agent chatAgent = agentService.getAgent(m.getAgentId()); Agent chatAgent = agentService.getAgent(m.getAgentId());
if (Objects.isNull(chatAgent) || !chatAgent.enableMemoryReview()) { if (Objects.isNull(chatAgent)) {
log.debug("Agent id {} not found or memory review disabled", m.getAgentId()); log.warn("Agent id {} not found or memory review disabled", m.getAgentId());
return; return;
} }
String promptStr = createPromptString(m);
ChatApp chatApp = chatAgent.getChatAppConfig().get(APP_KEY);
if (!chatApp.isEnable()) {
return;
}
String promptStr = createPromptString(m, chatApp.getPrompt());
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
keyPipelineLog.info("MemoryReviewTask reqPrompt:\n{}", promptStr); keyPipelineLog.info("MemoryReviewTask reqPrompt:\n{}", promptStr);
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel( ChatLanguageModel chatLanguageModel =
ModelConfigHelper.getChatModelConfig(chatAgent, ChatModelType.MEMORY_REVIEW)); ModelProvider.getChatModel(ModelConfigHelper.getChatModelConfig(chatApp));
if (Objects.nonNull(chatLanguageModel)) { if (Objects.nonNull(chatLanguageModel)) {
String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text(); String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();
keyPipelineLog.info("MemoryReviewTask modelResp:\n{}", response); keyPipelineLog.info("MemoryReviewTask modelResp:\n{}", response);
@@ -77,8 +89,8 @@ public class MemoryReviewTask {
} }
} }
private String createPromptString(ChatMemoryDO m) { private String createPromptString(ChatMemoryDO m, String promptTemplate) {
return String.format(INSTRUCTION, m.getQuestion(), m.getDbSchema(), m.getSideInfo(), return String.format(promptTemplate, m.getQuestion(), m.getDbSchema(), m.getSideInfo(),
m.getS2sql()); m.getS2sql());
} }

View File

@@ -1,8 +1,6 @@
package com.tencent.supersonic.chat.server.parser; package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.api.pojo.response.QueryResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager; import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
import com.tencent.supersonic.chat.server.pojo.ChatContext; import com.tencent.supersonic.chat.server.pojo.ChatContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext; import com.tencent.supersonic.chat.server.pojo.ParseContext;
@@ -11,10 +9,10 @@ import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.util.ModelConfigHelper; import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
import com.tencent.supersonic.chat.server.util.QueryReqConverter; import com.tencent.supersonic.chat.server.util.QueryReqConverter;
import com.tencent.supersonic.common.config.EmbeddingConfig; import com.tencent.supersonic.common.config.EmbeddingConfig;
import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl; import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch; import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
@@ -47,7 +45,6 @@ import java.util.Optional;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER; import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
@Slf4j @Slf4j
@@ -55,6 +52,7 @@ public class NL2SQLParser implements ChatQueryParser {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
public static final String APP_KEY_MULTI_TURN = "REWRITE_MULTI_TURN";
private static final String REWRITE_MULTI_TURN_INSTRUCTION = "" private static final String REWRITE_MULTI_TURN_INSTRUCTION = ""
+ "#Role: You are a data product manager experienced in data requirements." + "#Role: You are a data product manager experienced in data requirements."
+ "#Task: Your will be provided with current and history questions asked by a user," + "#Task: Your will be provided with current and history questions asked by a user,"
@@ -68,6 +66,7 @@ public class NL2SQLParser implements ChatQueryParser {
+ "#History Mapped Schema: {{history_schema}}" + "#History SQL: {{history_sql}}" + "#History Mapped Schema: {{history_schema}}" + "#History SQL: {{history_sql}}"
+ "#Rewritten Question: "; + "#Rewritten Question: ";
public static final String APP_KEY_ERROR_MESSAGE = "REWRITE_ERROR_MESSAGE";
private static final String REWRITE_ERROR_MESSAGE_INSTRUCTION = "" private static final String REWRITE_ERROR_MESSAGE_INSTRUCTION = ""
+ "#Role: You are a data business partner who closely interacts with business people.\n" + "#Role: You are a data business partner who closely interacts with business people.\n"
+ "#Task: Your will be provided with user input, system output and some examples, " + "#Task: Your will be provided with user input, system output and some examples, "
@@ -77,6 +76,16 @@ public class NL2SQLParser implements ChatQueryParser {
+ "#Input: {{user_question}}\n" + "#Output: {{system_message}}\n" + "#Input: {{user_question}}\n" + "#Output: {{system_message}}\n"
+ "#Examples: {{examples}}\n" + "#Response: "; + "#Examples: {{examples}}\n" + "#Response: ";
public NL2SQLParser() {
ChatAppManager.register(
ChatApp.builder().key(APP_KEY_MULTI_TURN).prompt(REWRITE_MULTI_TURN_INSTRUCTION)
.name("多轮对话改写").description("通过大模型根据历史对话来改写本轮对话").enable(false).build());
ChatAppManager.register(ChatApp.builder().key(APP_KEY_ERROR_MESSAGE)
.prompt(REWRITE_ERROR_MESSAGE_INSTRUCTION).name("异常提示改写")
.description("通过大模型将异常信息改写为更友好和引导性的提示用语").enable(false).build());
}
@Override @Override
public void parse(ParseContext parseContext, ParseResp parseResp) { public void parse(ParseContext parseContext, ParseResp parseResp) {
if (!parseContext.enableNL2SQL() || checkSkip(parseResp)) { if (!parseContext.enableNL2SQL() || checkSkip(parseResp)) {
@@ -97,10 +106,8 @@ public class NL2SQLParser implements ChatQueryParser {
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses()); parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
} else { } else {
if (!parseContext.isDisableLLM()) { if (!parseContext.isDisableLLM()) {
parseResp.setErrorMsg(rewriteErrorMessage(parseContext.getQueryText(), parseResp.setErrorMsg(rewriteErrorMessage(parseContext,
text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars(), text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars()));
parseContext.getAgent().getExamples(), ModelConfigHelper.getChatModelConfig(
parseContext.getAgent(), ChatModelType.RESPONSE_GENERATE)));
} }
} }
parseResp.setState(text2SqlParseResp.getState()); parseResp.setState(text2SqlParseResp.getState());
@@ -162,22 +169,11 @@ public class NL2SQLParser implements ChatQueryParser {
} }
private void processMultiTurn(ParseContext parseContext) { private void processMultiTurn(ParseContext parseContext) {
Agent agent = parseContext.getAgent(); ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_MULTI_TURN);
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); if (!chatApp.isEnable()) {
MultiTurnConfig agentMultiTurnConfig = agent.getMultiTurnConfig();
Boolean globalMultiTurnConfig =
Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
Boolean multiTurnConfig =
agentMultiTurnConfig != null ? agentMultiTurnConfig.isEnableMultiTurn()
: globalMultiTurnConfig;
if (!Boolean.TRUE.equals(multiTurnConfig)) {
return; return;
} }
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
ModelConfigHelper.getChatModelConfig(agent, ChatModelType.MULTI_TURN_REWRITE));
// derive mapping result of current question and parsing result of last question. // derive mapping result of current question and parsing result of last question.
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class); ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext); QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
@@ -203,9 +199,11 @@ public class NL2SQLParser implements ChatQueryParser {
variables.put("history_schema", histMapStr); variables.put("history_schema", histMapStr);
variables.put("history_sql", histSQL); variables.put("history_sql", histSQL);
Prompt prompt = PromptTemplate.from(REWRITE_MULTI_TURN_INSTRUCTION).apply(variables); Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variables);
keyPipelineLog.info("QueryRewrite reqPrompt:{}", prompt.text()); keyPipelineLog.info("QueryRewrite reqPrompt:{}", prompt.text());
ChatLanguageModel chatLanguageModel =
ModelProvider.getChatModel(ModelConfigHelper.getChatModelConfig(chatApp));
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage()); Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
String rewrittenQuery = response.content().text(); String rewrittenQuery = response.content().text();
keyPipelineLog.info("QueryRewrite modelResp:{}", rewrittenQuery); keyPipelineLog.info("QueryRewrite modelResp:{}", rewrittenQuery);
@@ -217,24 +215,30 @@ public class NL2SQLParser implements ChatQueryParser {
currentMapResult.getQueryText(), rewrittenQuery); currentMapResult.getQueryText(), rewrittenQuery);
} }
private String rewriteErrorMessage(String userQuestion, String errMsg, private String rewriteErrorMessage(ParseContext parseContext, String errMsg,
List<Text2SQLExemplar> similarExemplars, List<String> agentExamples, List<Text2SQLExemplar> similarExemplars) {
ChatModelConfig modelConfig) {
ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_ERROR_MESSAGE);
if (!chatApp.isEnable()) {
return errMsg;
}
Map<String, Object> variables = new HashMap<>(); Map<String, Object> variables = new HashMap<>();
variables.put("user_question", userQuestion); variables.put("user_question", parseContext.getQueryText());
variables.put("system_message", errMsg); variables.put("system_message", errMsg);
StringBuilder exampleStr = new StringBuilder(); StringBuilder exampleStr = new StringBuilder();
similarExemplars.forEach(e -> exampleStr.append( similarExemplars.forEach(e -> exampleStr.append(
String.format("<Question:{%s},Schema:{%s}> ", e.getQuestion(), e.getDbSchema()))); String.format("<Question:{%s},Schema:{%s}> ", e.getQuestion(), e.getDbSchema())));
agentExamples.forEach(e -> exampleStr.append(String.format("<Question:{%s}> ", e))); parseContext.getAgent().getExamples()
.forEach(e -> exampleStr.append(String.format("<Question:{%s}> ", e)));
variables.put("examples", exampleStr); variables.put("examples", exampleStr);
Prompt prompt = PromptTemplate.from(REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(variables); Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variables);
keyPipelineLog.info("ErrorRewrite reqPrompt:{}", prompt.text()); keyPipelineLog.info("ErrorRewrite reqPrompt:{}", prompt.text());
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig); ChatLanguageModel chatLanguageModel =
ModelProvider.getChatModel(ModelConfigHelper.getChatModelConfig(chatApp));
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage()); Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
String rewrittenMsg = response.content().text(); String rewrittenMsg = response.content().text();
keyPipelineLog.info("ErrorRewrite modelResp:{}", rewrittenMsg); keyPipelineLog.info("ErrorRewrite modelResp:{}", rewrittenMsg);

View File

@@ -33,15 +33,9 @@ public class AgentDO {
private Integer enableSearch; private Integer enableSearch;
private Integer enableMemoryReview;
private String toolConfig; private String toolConfig;
private String chatModelConfig; private String chatModelConfig;
private String multiTurnConfig;
private String visualConfig; private String visualConfig;
private String promptConfig;
} }

View File

@@ -10,12 +10,15 @@ import com.tencent.supersonic.chat.server.config.ChatModelParameters;
import com.tencent.supersonic.chat.server.pojo.ChatModel; import com.tencent.supersonic.chat.server.pojo.ChatModel;
import com.tencent.supersonic.chat.server.service.ChatModelService; import com.tencent.supersonic.chat.server.service.ChatModelService;
import com.tencent.supersonic.chat.server.util.ModelConfigHelper; import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Parameter; import com.tencent.supersonic.common.pojo.Parameter;
import com.tencent.supersonic.common.pojo.enums.ChatModelType; import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import com.tencent.supersonic.common.util.ChatAppManager;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@@ -51,11 +54,9 @@ public class ChatModelController {
return chatModelService.getChatModels(); return chatModelService.getChatModels();
} }
@RequestMapping("/getModelTypeList") @RequestMapping("/getModelAppList")
public List<ChatModelTypeResp> getModelTypeList() { public List<ChatApp> getModelAppList() {
return Arrays.stream(ChatModelType.values()).map(t -> ChatModelTypeResp.builder() return new ArrayList(ChatAppManager.getAllApps().values());
.type(t.toString()).name(t.getName()).description(t.getDescription()).build())
.collect(Collectors.toList());
} }
@RequestMapping("/getModelParameters") @RequestMapping("/getModelParameters")

View File

@@ -15,8 +15,8 @@ import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatModelService; import com.tencent.supersonic.chat.server.service.ChatModelService;
import com.tencent.supersonic.chat.server.service.ChatQueryService; import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.chat.server.service.MemoryService; import com.tencent.supersonic.chat.server.service.MemoryService;
import com.tencent.supersonic.chat.server.util.ModelConfigHelper; import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.enums.ChatModelType; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.chat.parser.llm.OnePassSCSqlGenStrategy; import com.tencent.supersonic.headless.chat.parser.llm.OnePassSCSqlGenStrategy;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
@@ -26,6 +26,7 @@ import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
@@ -53,12 +54,6 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
@Override @Override
public Agent createAgent(Agent agent, User user) { public Agent createAgent(Agent agent, User user) {
if (Objects.isNull(agent.getPromptConfig())
|| Objects.isNull(agent.getPromptConfig().getPromptTemplate())) {
PromptConfig promptConfig = new PromptConfig();
promptConfig.setPromptTemplate(OnePassSCSqlGenStrategy.INSTRUCTION.trim());
agent.setPromptConfig(promptConfig);
}
agent.createdBy(user.getName()); agent.createdBy(user.getName());
AgentDO agentDO = convert(agent); AgentDO agentDO = convert(agent);
save(agentDO); save(agentDO);
@@ -69,12 +64,6 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
@Override @Override
public Agent updateAgent(Agent agent, User user) { public Agent updateAgent(Agent agent, User user) {
if (Objects.isNull(agent.getPromptConfig())
|| Objects.isNull(agent.getPromptConfig().getPromptTemplate())) {
PromptConfig promptConfig = new PromptConfig();
promptConfig.setPromptTemplate(OnePassSCSqlGenStrategy.INSTRUCTION.trim());
agent.setPromptConfig(promptConfig);
}
agent.updatedBy(user.getName()); agent.updatedBy(user.getName());
updateById(convert(agent)); updateById(convert(agent));
executeAgentExamplesAsync(agent); executeAgentExamplesAsync(agent);
@@ -105,10 +94,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
} }
private synchronized void doExecuteAgentExamples(Agent agent) { private synchronized void doExecuteAgentExamples(Agent agent) {
if (!agent.containsDatasetTool() if (!agent.containsDatasetTool() || !agent.enableMemoryReview()
|| !agent.enableMemoryReview()
|| !ModelConfigHelper.testConnection(
ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL))
|| CollectionUtils.isEmpty(agent.getExamples())) { || CollectionUtils.isEmpty(agent.getExamples())) {
return; return;
} }
@@ -144,11 +130,8 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
BeanUtils.copyProperties(agentDO, agent); BeanUtils.copyProperties(agentDO, agent);
agent.setToolConfig(agentDO.getToolConfig()); agent.setToolConfig(agentDO.getToolConfig());
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class)); agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
agent.setChatModelConfig( agent.setChatAppConfig(
JsonUtil.toMap(agentDO.getChatModelConfig(), ChatModelType.class, Integer.class)); JsonUtil.toMap(agentDO.getChatModelConfig(), String.class, ChatApp.class));
agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class));
agent.setMultiTurnConfig(
JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class)); agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));
return agent; return agent;
} }
@@ -158,10 +141,8 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
BeanUtils.copyProperties(agent, agentDO); BeanUtils.copyProperties(agent, agentDO);
agentDO.setToolConfig(agent.getToolConfig()); agentDO.setToolConfig(agent.getToolConfig());
agentDO.setExamples(JsonUtil.toString(agent.getExamples())); agentDO.setExamples(JsonUtil.toString(agent.getExamples()));
agentDO.setChatModelConfig(JsonUtil.toString(agent.getChatModelConfig())); agentDO.setChatModelConfig(JsonUtil.toString(agent.getChatAppConfig()));
agentDO.setMultiTurnConfig(JsonUtil.toString(agent.getMultiTurnConfig()));
agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig())); agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig()));
agentDO.setPromptConfig(JsonUtil.toString(agent.getPromptConfig()));
if (agentDO.getStatus() == null) { if (agentDO.getStatus() == null) {
agentDO.setStatus(1); agentDO.setStatus(1);
} }

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.service.ChatModelService; import com.tencent.supersonic.chat.server.service.ChatModelService;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.enums.ChatModelType; import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
@@ -27,13 +28,10 @@ public class ModelConfigHelper {
} }
} }
public static ChatModelConfig getChatModelConfig(Agent agent, ChatModelType modelType) { public static ChatModelConfig getChatModelConfig(ChatApp chatApp) {
ChatModelConfig chatModelConfig = null; ChatModelService chatModelService = ContextUtils.getBean(ChatModelService.class);
if (agent.getChatModelConfig().containsKey(modelType)) { ChatModelConfig chatModelConfig =
Integer chatModelId = agent.getChatModelConfig().get(modelType); chatModelService.getChatModel(chatApp.getChatModelId()).getConfig();
ChatModelService chatModelService = ContextUtils.getBean(ChatModelService.class);
chatModelConfig = chatModelService.getChatModel(chatModelId).getConfig();
}
return chatModelConfig; return chatModelConfig;
} }
} }

View File

@@ -3,10 +3,12 @@ package com.tencent.supersonic.chat.server.util;
import com.tencent.supersonic.chat.server.agent.Agent; import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.pojo.ChatContext; import com.tencent.supersonic.chat.server.pojo.ChatContext;
import com.tencent.supersonic.chat.server.pojo.ParseContext; import com.tencent.supersonic.chat.server.pojo.ParseContext;
import com.tencent.supersonic.chat.server.service.ChatModelService;
import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.enums.ChatModelType; import com.tencent.supersonic.common.pojo.enums.ChatModelType;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq; import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
import org.apache.commons.collections.MapUtils; import org.apache.commons.collections.MapUtils;
@@ -37,10 +39,7 @@ public class QueryReqConverter {
&& MapUtils.isNotEmpty(queryNLReq.getMapInfo().getDataSetElementMatches())) { && MapUtils.isNotEmpty(queryNLReq.getMapInfo().getDataSetElementMatches())) {
queryNLReq.setMapInfo(queryNLReq.getMapInfo()); queryNLReq.setMapInfo(queryNLReq.getMapInfo());
} }
ChatModelConfig chatModelConfig = queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig());
ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL);
queryNLReq.setModelConfig(chatModelConfig);
queryNLReq.setCustomPrompt(agent.getPromptConfig().getPromptTemplate());
if (chatCtx != null) { if (chatCtx != null) {
queryNLReq.setContextParseInfo(chatCtx.getParseInfo()); queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
} }

View File

@@ -0,0 +1,22 @@
package com.tencent.supersonic.common.pojo;
import com.fasterxml.jackson.annotation.JsonIgnore;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@Builder
@AllArgsConstructor
@NoArgsConstructor
public class ChatApp {
private String key;
private String name;
private String description;
private String prompt;
private boolean enable;
private Integer chatModelId;
@JsonIgnore
private ChatModelConfig chatModelConfig;
}

View File

@@ -0,0 +1,23 @@
package com.tencent.supersonic.common.util;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.tencent.supersonic.common.pojo.ChatApp;
import java.util.List;
import java.util.Map;
public class ChatAppManager {
private static final Map<String, ChatApp> chatApps = Maps.newConcurrentMap();
public static void register(ChatApp chatApp) {
if (chatApps.containsKey(chatApp.getKey())) {
throw new RuntimeException("Duplicate chat app key is disallowed.");
}
chatApps.put(chatApp.getKey(), chatApp);
}
public static Map<String, ChatApp> getAllApps() {
return chatApps;
}
}

View File

@@ -3,6 +3,7 @@ package com.tencent.supersonic.headless.api.pojo.request;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.google.common.collect.Sets; import com.google.common.collect.Sets;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
@@ -13,6 +14,7 @@ import com.tencent.supersonic.headless.api.pojo.enums.MapModeEnum;
import lombok.Data; import lombok.Data;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Set; import java.util.Set;
@Data @Data
@@ -26,8 +28,7 @@ public class QueryNLReq {
private MapModeEnum mapModeEnum = MapModeEnum.STRICT; private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
private SchemaMapInfo mapInfo = new SchemaMapInfo(); private SchemaMapInfo mapInfo = new SchemaMapInfo();
private QueryDataType queryDataType = QueryDataType.ALL; private QueryDataType queryDataType = QueryDataType.ALL;
private ChatModelConfig modelConfig; private Map<String, ChatApp> chatAppConfig;
private String customPrompt;
private List<Text2SQLExemplar> dynamicExemplars = Lists.newArrayList(); private List<Text2SQLExemplar> dynamicExemplars = Lists.newArrayList();
private SemanticParseInfo contextParseInfo; private SemanticParseInfo contextParseInfo;
} }

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.pojo.enums.Text2SQLType; import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
@@ -47,13 +48,14 @@ public class ChatQueryContext {
private SchemaMapInfo mapInfo = new SchemaMapInfo(); private SchemaMapInfo mapInfo = new SchemaMapInfo();
private SemanticParseInfo contextParseInfo; private SemanticParseInfo contextParseInfo;
private MapModeEnum mapModeEnum = MapModeEnum.STRICT; private MapModeEnum mapModeEnum = MapModeEnum.STRICT;
private QueryDataType queryDataType = QueryDataType.ALL;
@JsonIgnore @JsonIgnore
private SemanticSchema semanticSchema; private SemanticSchema semanticSchema;
@JsonIgnore @JsonIgnore
private ChatWorkflowState chatWorkflowState; private ChatWorkflowState chatWorkflowState;
private QueryDataType queryDataType = QueryDataType.ALL; @JsonIgnore
private ChatModelConfig modelConfig; private Map<String, ChatApp> chatAppConfig;
private String customPrompt; @JsonIgnore
private List<Text2SQLExemplar> dynamicExemplars; private List<Text2SQLExemplar> dynamicExemplars;
public List<SemanticQuery> getCandidateQueries() { public List<SemanticQuery> getCandidateQueries() {

View File

@@ -1,5 +1,7 @@
package com.tencent.supersonic.headless.chat.corrector; package com.tencent.supersonic.headless.chat.corrector;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo; import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
@@ -23,11 +25,11 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
public static final String INSTRUCTION = "" public static final String APP_KEY = "S2SQL_CORRECTOR";
private static final String INSTRUCTION = ""
+ "\n#Role: You are a senior data engineer experienced in writing SQL." + "\n#Role: You are a senior data engineer experienced in writing SQL."
+ "\n#Task: Your will be provided with a user question and the SQL written by a junior engineer," + "\n#Task: Your will be provided with a user question and the SQL written by a junior engineer,"
+ "please take a review and help correct it if necessary." + "please take a review and help correct it if necessary." + "\n#Rules: "
+ "\n#Rules: "
+ "\n1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),sql=(corrected sql if NEGATIVE; empty string if POSITIVE)`." + "\n1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),sql=(corrected sql if NEGATIVE; empty string if POSITIVE)`."
+ "\n2.NO NEED to check date filters as the junior engineer seldom makes mistakes in this regard." + "\n2.NO NEED to check date filters as the junior engineer seldom makes mistakes in this regard."
+ "\n3.DO NOT miss the AGGREGATE operator of metrics, always add it as needed." + "\n3.DO NOT miss the AGGREGATE operator of metrics, always add it as needed."
@@ -36,6 +38,11 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
+ "\n6.ALWAYS translate alias created by `AS` command to the same language as the `#Question`." + "\n6.ALWAYS translate alias created by `AS` command to the same language as the `#Question`."
+ "\n#Question:{{question}} #InputSQL:{{sql}} #Response:"; + "\n#Question:{{question}} #InputSQL:{{sql}} #Response:";
public LLMSqlCorrector() {
ChatAppManager.register(ChatApp.builder().key(APP_KEY).prompt(INSTRUCTION).name("语义SQL修正")
.description("").enable(false).build());
}
@Data @Data
@ToString @ToString
static class SemanticSql { static class SemanticSql {
@@ -52,14 +59,16 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
@Override @Override
public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { public void doCorrect(ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) {
if (!chatQueryContext.getText2SQLType().enableLLM()) { ChatApp chatApp = chatQueryContext.getChatAppConfig().get(APP_KEY);
if (!chatQueryContext.getText2SQLType().enableLLM() || !chatApp.isEnable()) {
return; return;
} }
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatQueryContext.getModelConfig()); ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(chatApp.getChatModelConfig());
SemanticSqlExtractor extractor = SemanticSqlExtractor extractor =
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel); AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
Prompt prompt = generatePrompt(chatQueryContext.getQueryText(), semanticParseInfo); Prompt prompt = generatePrompt(chatQueryContext.getQueryText(), semanticParseInfo,
chatApp.getPrompt());
keyPipelineLog.info("LLMSqlCorrector reqPrompt:\n{}", prompt.text()); keyPipelineLog.info("LLMSqlCorrector reqPrompt:\n{}", prompt.text());
SemanticSql s2Sql = extractor.generateSemanticSql(prompt.toUserMessage().singleText()); SemanticSql s2Sql = extractor.generateSemanticSql(prompt.toUserMessage().singleText());
keyPipelineLog.info("LLMSqlCorrector modelResp:\n{}", s2Sql); keyPipelineLog.info("LLMSqlCorrector modelResp:\n{}", s2Sql);
@@ -68,12 +77,12 @@ public class LLMSqlCorrector extends BaseSemanticCorrector {
} }
} }
private Prompt generatePrompt(String queryText, SemanticParseInfo semanticParseInfo) { private Prompt generatePrompt(String queryText, SemanticParseInfo semanticParseInfo,
String promptTemplate) {
Map<String, Object> variable = new HashMap<>(); Map<String, Object> variable = new HashMap<>();
variable.put("question", queryText); variable.put("question", queryText);
variable.put("sql", semanticParseInfo.getSqlInfo().getCorrectedS2SQL()); variable.put("sql", semanticParseInfo.getSqlInfo().getCorrectedS2SQL());
String promptTemplate = INSTRUCTION;
return PromptTemplate.from(promptTemplate).apply(variable); return PromptTemplate.from(promptTemplate).apply(variable);
} }
} }

View File

@@ -74,8 +74,7 @@ public class LLMRequestService {
llmReq.setTerms(getMappedTerms(queryCtx, dataSetId)); llmReq.setTerms(getMappedTerms(queryCtx, dataSetId));
llmReq.setSqlGenType( llmReq.setSqlGenType(
LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE))); LLMReq.SqlGenType.valueOf(parserConfig.getParameterValue(PARSER_STRATEGY_TYPE)));
llmReq.setModelConfig(queryCtx.getModelConfig()); llmReq.setChatAppConfig(queryCtx.getChatAppConfig());
llmReq.setCustomPrompt(queryCtx.getCustomPrompt());
llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars()); llmReq.setDynamicExemplars(queryCtx.getDynamicExemplars());
return llmReq; return llmReq;

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.chat.parser.llm; package com.tencent.supersonic.headless.chat.parser.llm;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.ContextUtils;
import com.tencent.supersonic.headless.chat.ChatQueryContext; import com.tencent.supersonic.headless.chat.ChatQueryContext;
import com.tencent.supersonic.headless.chat.parser.SemanticParser; import com.tencent.supersonic.headless.chat.parser.SemanticParser;
@@ -69,10 +70,12 @@ public class LLMSqlParser implements SemanticParser {
} catch (Exception e) { } catch (Exception e) {
log.error("currentRetryRound:{}, runText2SQL failed", currentRetry, e); log.error("currentRetryRound:{}, runText2SQL failed", currentRetry, e);
} }
Double temperature = llmReq.getModelConfig().getTemperature(); ChatModelConfig chatModelConfig = llmReq.getChatAppConfig()
.get(OnePassSCSqlGenStrategy.APP_KEY).getChatModelConfig();
Double temperature = chatModelConfig.getTemperature();
if (temperature == 0) { if (temperature == 0) {
// 报错时增加随机性,减少无效重试 // 报错时增加随机性,减少无效重试
llmReq.getModelConfig().setTemperature(0.5); chatModelConfig.setTemperature(0.5);
} }
currentRetry++; currentRetry++;
} }

View File

@@ -1,7 +1,9 @@
package com.tencent.supersonic.headless.chat.parser.llm; package com.tencent.supersonic.headless.chat.parser.llm;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMReq;
import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp; import com.tencent.supersonic.headless.chat.query.llm.s2sql.LLMResp;
import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.chat.ChatLanguageModel;
@@ -11,7 +13,6 @@ import dev.langchain4j.model.output.structured.Description;
import dev.langchain4j.service.AiServices; import dev.langchain4j.service.AiServices;
import lombok.Data; import lombok.Data;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.Pair;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
@@ -24,6 +25,7 @@ import java.util.concurrent.ConcurrentHashMap;
@Slf4j @Slf4j
public class OnePassSCSqlGenStrategy extends SqlGenStrategy { public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
public static final String APP_KEY = "S2SQL_PARSER";
public static final String INSTRUCTION = "" public static final String INSTRUCTION = ""
+ "\n#Role: You are a data analyst experienced in SQL languages." + "\n#Role: You are a data analyst experienced in SQL languages."
+ "\n#Task: You will be provided with a natural language question asked by users," + "\n#Task: You will be provided with a natural language question asked by users,"
@@ -40,6 +42,11 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
+ "\n#Exemplars: {{exemplar}}" + "\n#Exemplars: {{exemplar}}"
+ "\n#Question: Question:{{question}},Schema:{{schema}},SideInfo:{{information}}"; + "\n#Question: Question:{{question}},Schema:{{schema}},SideInfo:{{information}}";
public OnePassSCSqlGenStrategy() {
ChatAppManager.register(ChatApp.builder().key(APP_KEY).prompt(INSTRUCTION).name("语义SQL解析")
.description("通过大模型做语义解析生成S2SQL").enable(true).build());
}
@Data @Data
static class SemanticSql { static class SemanticSql {
@Description("thought or remarks to tell users about the sql, make it short.") @Description("thought or remarks to tell users about the sql, make it short.")
@@ -62,15 +69,17 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
List<List<Text2SQLExemplar>> exemplarsList = promptHelper.getFewShotExemplars(llmReq); List<List<Text2SQLExemplar>> exemplarsList = promptHelper.getFewShotExemplars(llmReq);
// 2.generate sql generation prompt for each self-consistency inference // 2.generate sql generation prompt for each self-consistency inference
ChatApp chatApp = llmReq.getChatAppConfig().get(APP_KEY);
ChatLanguageModel chatLanguageModel = getChatLanguageModel(chatApp.getChatModelConfig());
SemanticSqlExtractor extractor =
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
Map<Prompt, List<Text2SQLExemplar>> prompt2Exemplar = new HashMap<>(); Map<Prompt, List<Text2SQLExemplar>> prompt2Exemplar = new HashMap<>();
for (List<Text2SQLExemplar> exemplars : exemplarsList) { for (List<Text2SQLExemplar> exemplars : exemplarsList) {
llmReq.setDynamicExemplars(exemplars); llmReq.setDynamicExemplars(exemplars);
Prompt prompt = generatePrompt(llmReq, llmResp); Prompt prompt = generatePrompt(llmReq, llmResp, chatApp);
prompt2Exemplar.put(prompt, exemplars); prompt2Exemplar.put(prompt, exemplars);
} }
ChatLanguageModel chatLanguageModel = getChatLanguageModel(llmReq.getModelConfig());
SemanticSqlExtractor extractor =
AiServices.create(SemanticSqlExtractor.class, chatLanguageModel);
// 3.perform multiple self-consistency inferences parallelly // 3.perform multiple self-consistency inferences parallelly
Map<String, Prompt> output2Prompt = new ConcurrentHashMap<>(); Map<String, Prompt> output2Prompt = new ConcurrentHashMap<>();
@@ -92,7 +101,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
return llmResp; return llmResp;
} }
private Prompt generatePrompt(LLMReq llmReq, LLMResp llmResp) { private Prompt generatePrompt(LLMReq llmReq, LLMResp llmResp, ChatApp chatApp) {
StringBuilder exemplars = new StringBuilder(); StringBuilder exemplars = new StringBuilder();
for (Text2SQLExemplar exemplar : llmReq.getDynamicExemplars()) { for (Text2SQLExemplar exemplar : llmReq.getDynamicExemplars()) {
String exemplarStr = String.format("\nQuestion:%s,Schema:%s,SideInfo:%s,SQL:%s", String exemplarStr = String.format("\nQuestion:%s,Schema:%s,SideInfo:%s,SQL:%s",
@@ -112,10 +121,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy {
variable.put("information", sideInformation); variable.put("information", sideInformation);
// use custom prompt template if provided. // use custom prompt template if provided.
String promptTemplate = INSTRUCTION; String promptTemplate = chatApp.getPrompt();
if (StringUtils.isNotBlank(llmReq.getCustomPrompt())) {
promptTemplate = llmReq.getCustomPrompt();
}
return PromptTemplate.from(promptTemplate).apply(variable); return PromptTemplate.from(promptTemplate).apply(variable);
} }

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.chat.query.llm.s2sql;
import com.fasterxml.jackson.annotation.JsonValue; import com.fasterxml.jackson.annotation.JsonValue;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.ChatModelConfig; import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.common.pojo.Text2SQLExemplar; import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
import com.tencent.supersonic.headless.api.pojo.SchemaElement; import com.tencent.supersonic.headless.api.pojo.SchemaElement;
@@ -10,6 +11,7 @@ import org.apache.commons.collections4.CollectionUtils;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@@ -21,7 +23,7 @@ public class LLMReq {
private String currentDate; private String currentDate;
private String priorExts; private String priorExts;
private SqlGenType sqlGenType; private SqlGenType sqlGenType;
private ChatModelConfig modelConfig; private Map<String, ChatApp> chatAppConfig;
private String customPrompt; private String customPrompt;
private List<Text2SQLExemplar> dynamicExemplars; private List<Text2SQLExemplar> dynamicExemplars;

View File

@@ -16,9 +16,11 @@ import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.server.plugin.build.WebBase; import com.tencent.supersonic.chat.server.plugin.build.WebBase;
import com.tencent.supersonic.chat.server.plugin.build.webpage.WebPageQuery; import com.tencent.supersonic.chat.server.plugin.build.webpage.WebPageQuery;
import com.tencent.supersonic.chat.server.plugin.build.webservice.WebServiceQuery; import com.tencent.supersonic.chat.server.plugin.build.webservice.WebServiceQuery;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.JoinCondition; import com.tencent.supersonic.common.pojo.JoinCondition;
import com.tencent.supersonic.common.pojo.ModelRela; import com.tencent.supersonic.common.pojo.ModelRela;
import com.tencent.supersonic.common.pojo.enums.*; import com.tencent.supersonic.common.pojo.enums.*;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.DataSetDetail; import com.tencent.supersonic.headless.api.pojo.DataSetDetail;
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig; import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
@@ -148,6 +150,7 @@ public class S2VisitsDemo extends S2BaseDemo {
agent.setEnableSearch(1); agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList("近15天超音数访问次数汇总", "按部门统计超音数的访问人数", "对比alice和lucy的停留时长", agent.setExamples(Lists.newArrayList("近15天超音数访问次数汇总", "按部门统计超音数的访问人数", "对比alice和lucy的停留时长",
"过去30天访问次数最高的部门top3", "近1个月总访问次数超过100次的部门有几个", "过去半个月每个核心用户的总停留时长")); "过去30天访问次数最高的部门top3", "近1个月总访问次数超过100次的部门有几个", "过去半个月每个核心用户的总停留时长"));
// configure tools // configure tools
ToolConfig toolConfig = new ToolConfig(); ToolConfig toolConfig = new ToolConfig();
DatasetTool datasetTool = new DatasetTool(); DatasetTool datasetTool = new DatasetTool();
@@ -157,16 +160,10 @@ public class S2VisitsDemo extends S2BaseDemo {
toolConfig.getTools().add(datasetTool); toolConfig.getTools().add(datasetTool);
agent.setToolConfig(JSONObject.toJSONString(toolConfig)); agent.setToolConfig(JSONObject.toJSONString(toolConfig));
// configure chat models // configure chat apps
Map<ChatModelType, Integer> chatModelConfig = Maps.newHashMap(); Map<String, ChatApp> chatAppConfig = Maps.newHashMap(ChatAppManager.getAllApps());
chatModelConfig.put(ChatModelType.TEXT_TO_SQL, demoChatModel.getId()); chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId()));
chatModelConfig.put(ChatModelType.MEMORY_REVIEW, demoChatModel.getId()); agent.setChatAppConfig(chatAppConfig);
chatModelConfig.put(ChatModelType.RESPONSE_GENERATE, demoChatModel.getId());
chatModelConfig.put(ChatModelType.MULTI_TURN_REWRITE, demoChatModel.getId());
agent.setChatModelConfig(chatModelConfig);
MultiTurnConfig multiTurnConfig = new MultiTurnConfig(true);
agent.setMultiTurnConfig(multiTurnConfig);
Agent agentCreated = agentService.createAgent(agent, defaultUser); Agent agentCreated = agentService.createAgent(agent, defaultUser);
return agentCreated.getId(); return agentCreated.getId();
} }

View File

@@ -26,10 +26,6 @@ public class SmallTalkDemo extends S2BaseDemo {
ToolConfig toolConfig = new ToolConfig(); ToolConfig toolConfig = new ToolConfig();
agent.setToolConfig(JSONObject.toJSONString(toolConfig)); agent.setToolConfig(JSONObject.toJSONString(toolConfig));
agent.setExamples(Lists.newArrayList("如何才能变帅", "如何才能赚更多钱", "如何才能世界和平")); agent.setExamples(Lists.newArrayList("如何才能变帅", "如何才能赚更多钱", "如何才能世界和平"));
MultiTurnConfig multiTurnConfig = new MultiTurnConfig();
multiTurnConfig.setEnableMultiTurn(true);
agent.setMultiTurnConfig(multiTurnConfig);
agentService.createAgent(agent, defaultUser); agentService.createAgent(agent, defaultUser);
} }

View File

@@ -387,3 +387,8 @@ CREATE TABLE IF NOT EXISTS `s2_chat_model` (
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='对话大模型实例表'; ) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='对话大模型实例表';
ALTER TABLE s2_agent RENAME COLUMN config TO tool_config; ALTER TABLE s2_agent RENAME COLUMN config TO tool_config;
ALTER TABLE s2_agent RENAME COLUMN model_config TO chat_model_config; ALTER TABLE s2_agent RENAME COLUMN model_config TO chat_model_config;
--20241011
ALTER TABLE s2_agent DROP COLUMN prompt_config;
ALTER TABLE s2_agent DROP COLUMN multi_turn_config;
ALTER TABLE s2_agent DROP COLUMN enable_memory_review;

View File

@@ -391,15 +391,12 @@ CREATE TABLE IF NOT EXISTS s2_agent
tool_config varchar(2000) null, tool_config varchar(2000) null,
llm_config varchar(2000) null, llm_config varchar(2000) null,
chat_model_config varchar(6000) null, chat_model_config varchar(6000) null,
prompt_config varchar(5000) null,
multi_turn_config varchar(2000) null,
visual_config varchar(2000) null, visual_config varchar(2000) null,
created_by varchar(100) null, created_by varchar(100) null,
created_at TIMESTAMP null, created_at TIMESTAMP null,
updated_by varchar(100) null, updated_by varchar(100) null,
updated_at TIMESTAMP null, updated_at TIMESTAMP null,
enable_search int null, enable_search int null,
enable_memory_review int null,
PRIMARY KEY (`id`) PRIMARY KEY (`id`)
); COMMENT ON TABLE s2_agent IS 'agent information table'; ); COMMENT ON TABLE s2_agent IS 'agent information table';

View File

@@ -73,11 +73,8 @@ CREATE TABLE IF NOT EXISTS `s2_agent` (
`tool_config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL, `tool_config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL,
`llm_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL, `llm_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,
`chat_model_config` text COLLATE utf8_unicode_ci DEFAULT NULL, `chat_model_config` text COLLATE utf8_unicode_ci DEFAULT NULL,
`prompt_config` text COLLATE utf8_unicode_ci DEFAULT NULL,
`multi_turn_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,
`visual_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL, `visual_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,
`enable_search` tinyint DEFAULT 1, `enable_search` tinyint DEFAULT 1,
`enable_memory_review` tinyint DEFAULT 0,
`created_by` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL, `created_by` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL,
`created_at` datetime DEFAULT NULL, `created_at` datetime DEFAULT NULL,
`updated_by` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL, `updated_by` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL,

View File

@@ -8,7 +8,9 @@ import com.tencent.supersonic.chat.BaseTest;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.agent.*; import com.tencent.supersonic.chat.server.agent.*;
import com.tencent.supersonic.chat.server.pojo.ChatModel; import com.tencent.supersonic.chat.server.pojo.ChatModel;
import com.tencent.supersonic.common.pojo.enums.ChatModelType; import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.headless.chat.corrector.LLMSqlCorrector;
import com.tencent.supersonic.util.DataUtils; import com.tencent.supersonic.util.DataUtils;
import com.tencent.supersonic.util.LLMConfigUtils; import com.tencent.supersonic.util.LLMConfigUtils;
import org.junit.jupiter.api.*; import org.junit.jupiter.api.*;
@@ -135,22 +137,22 @@ public class Text2SQLEval extends BaseTest {
Agent agent = new Agent(); Agent agent = new Agent();
agent.setName("Agent for Test"); agent.setName("Agent for Test");
ToolConfig toolConfig = new ToolConfig(); ToolConfig toolConfig = new ToolConfig();
toolConfig.getTools().add(getLLMQueryTool()); toolConfig.getTools().add(getDatasetTool());
agent.setToolConfig(JSONObject.toJSONString(toolConfig)); agent.setToolConfig(JSONObject.toJSONString(toolConfig));
ChatModel chatModel = new ChatModel(); ChatModel chatModel = new ChatModel();
chatModel.setName("Text2SQL LLM"); chatModel.setName("Text2SQL LLM");
chatModel.setConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3)); chatModel.setConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3));
chatModel = chatModelService.createChatModel(chatModel, User.getDefaultUser()); chatModel = chatModelService.createChatModel(chatModel, User.getDefaultUser());
Map<ChatModelType, Integer> chatModelConfig = Maps.newHashMap(); Integer chatModelId = chatModel.getId();
chatModelConfig.put(ChatModelType.TEXT_TO_SQL, chatModel.getId()); // configure chat apps
agent.setChatModelConfig(chatModelConfig); Map<String, ChatApp> chatAppConfig = Maps.newHashMap(ChatAppManager.getAllApps());
MultiTurnConfig multiTurnConfig = new MultiTurnConfig(); chatAppConfig.values().forEach(app -> app.setChatModelId(chatModelId));
multiTurnConfig.setEnableMultiTurn(enableMultiturn); chatAppConfig.get(LLMSqlCorrector.APP_KEY).setEnable(true);
agent.setMultiTurnConfig(multiTurnConfig); agent.setChatAppConfig(chatAppConfig);
return agent; return agent;
} }
private static DatasetTool getLLMQueryTool() { private static DatasetTool getDatasetTool() {
DatasetTool datasetTool = new DatasetTool(); DatasetTool datasetTool = new DatasetTool();
datasetTool.setType(AgentToolType.DATASET); datasetTool.setType(AgentToolType.DATASET);
datasetTool.setDataSetIds(Lists.newArrayList(-1L)); datasetTool.setDataSetIds(Lists.newArrayList(-1L));

View File

@@ -391,15 +391,12 @@ CREATE TABLE IF NOT EXISTS s2_agent
tool_config varchar(2000) null, tool_config varchar(2000) null,
llm_config varchar(2000) null, llm_config varchar(2000) null,
chat_model_config varchar(6000) null, chat_model_config varchar(6000) null,
prompt_config varchar(5000) null,
multi_turn_config varchar(2000) null,
visual_config varchar(2000) null, visual_config varchar(2000) null,
created_by varchar(100) null, created_by varchar(100) null,
created_at TIMESTAMP null, created_at TIMESTAMP null,
updated_by varchar(100) null, updated_by varchar(100) null,
updated_at TIMESTAMP null, updated_at TIMESTAMP null,
enable_search int null, enable_search int null,
enable_memory_review int null,
PRIMARY KEY (`id`) PRIMARY KEY (`id`)
); COMMENT ON TABLE s2_agent IS 'agent information table'; ); COMMENT ON TABLE s2_agent IS 'agent information table';