mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-15 06:27:21 +00:00
[feature][headless-chat]Introduce ChatApp to support more flexible chat model config.#1739
This commit is contained in:
@@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.server.agent;
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||
import lombok.Data;
|
||||
@@ -24,6 +25,7 @@ public class Agent extends RecordInfo {
|
||||
private Integer enableMemoryReview;
|
||||
private String toolConfig;
|
||||
private Map<ChatModelType, Integer> chatModelConfig = Collections.EMPTY_MAP;
|
||||
private Map<String, ChatApp> chatAppConfig = Collections.EMPTY_MAP;
|
||||
private PromptConfig promptConfig;
|
||||
private MultiTurnConfig multiTurnConfig;
|
||||
private VisualConfig visualConfig;
|
||||
|
||||
@@ -8,8 +8,8 @@ import com.tencent.supersonic.chat.server.parser.ParserConfig;
|
||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
|
||||
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||
import dev.langchain4j.data.message.AiMessage;
|
||||
@@ -28,26 +28,35 @@ import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULT
|
||||
|
||||
public class PlainTextExecutor implements ChatQueryExecutor {
|
||||
|
||||
private static final String APP_KEY = "SMALL_TALK";
|
||||
private static final String INSTRUCTION = "" + "#Role: You are a nice person to talk to.\n"
|
||||
+ "#Task: Respond quickly and nicely to the user."
|
||||
+ "#Rules: 1.ALWAYS use the same language as the input.\n" + "#History Inputs: %s\n"
|
||||
+ "#Current Input: %s\n" + "#Your response: ";
|
||||
|
||||
public PlainTextExecutor() {
|
||||
ChatAppManager.register(ChatApp.builder().key(APP_KEY).prompt(INSTRUCTION).name("闲聊对话")
|
||||
.description("直接将原始输入透传大模型").enable(true).build());
|
||||
}
|
||||
|
||||
@Override
|
||||
public QueryResult execute(ExecuteContext executeContext) {
|
||||
if (!"PLAIN_TEXT".equals(executeContext.getParseInfo().getQueryMode())) {
|
||||
return null;
|
||||
}
|
||||
|
||||
String promptStr = String.format(INSTRUCTION, getHistoryInputs(executeContext),
|
||||
executeContext.getQueryText());
|
||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
|
||||
ChatApp chatApp = chatAgent.getChatAppConfig().get(APP_KEY);
|
||||
if (!chatApp.isEnable()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
|
||||
ModelConfigHelper.getChatModelConfig(chatAgent, ChatModelType.RESPONSE_GENERATE));
|
||||
String promptStr = String.format(chatApp.getPrompt(), getHistoryInputs(executeContext),
|
||||
executeContext.getQueryText());
|
||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||
ChatLanguageModel chatLanguageModel =
|
||||
ModelProvider.getChatModel(chatApp.getChatModelConfig());
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
|
||||
QueryResult result = new QueryResult();
|
||||
@@ -60,25 +69,12 @@ public class PlainTextExecutor implements ChatQueryExecutor {
|
||||
|
||||
private String getHistoryInputs(ExecuteContext executeContext) {
|
||||
StringBuilder historyInput = new StringBuilder();
|
||||
List<QueryResp> queryResps = getHistoryQueries(executeContext.getChatId(), 5);
|
||||
queryResps.stream().forEach(p -> {
|
||||
historyInput.append(p.getQueryText());
|
||||
historyInput.append(";");
|
||||
|
||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
|
||||
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
MultiTurnConfig agentMultiTurnConfig = chatAgent.getMultiTurnConfig();
|
||||
Boolean globalMultiTurnConfig =
|
||||
Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
|
||||
Boolean multiTurnConfig =
|
||||
agentMultiTurnConfig != null ? agentMultiTurnConfig.isEnableMultiTurn()
|
||||
: globalMultiTurnConfig;
|
||||
|
||||
if (Boolean.TRUE.equals(multiTurnConfig)) {
|
||||
List<QueryResp> queryResps = getHistoryQueries(executeContext.getChatId(), 5);
|
||||
queryResps.stream().forEach(p -> {
|
||||
historyInput.append(p.getQueryText());
|
||||
historyInput.append(";");
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
return historyInput.toString();
|
||||
}
|
||||
|
||||
@@ -6,7 +6,8 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
|
||||
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||
import dev.langchain4j.model.input.Prompt;
|
||||
import dev.langchain4j.model.input.PromptTemplate;
|
||||
@@ -29,11 +30,11 @@ public class MemoryReviewTask {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
public static final String APP_KEY = "MEMORY_REVIEW";
|
||||
private static final String INSTRUCTION = ""
|
||||
+ "\n#Role: You are a senior data engineer experienced in writing SQL."
|
||||
+ "\n#Task: Your will be provided with a user question and the SQL written by a junior engineer,"
|
||||
+ "please take a review and give your opinion."
|
||||
+ "\n#Rules: "
|
||||
+ "please take a review and give your opinion." + "\n#Rules: "
|
||||
+ "\n1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`."
|
||||
+ "\n2.NO NEED to check date filters as the junior engineer seldom makes mistakes in this regard."
|
||||
+ "\n#Question: %s" + "\n#Schema: %s" + "\n#SideInfo: %s" + "\n#SQL: %s"
|
||||
@@ -47,6 +48,11 @@ public class MemoryReviewTask {
|
||||
@Autowired
|
||||
private AgentService agentService;
|
||||
|
||||
public MemoryReviewTask() {
|
||||
ChatAppManager.register(ChatApp.builder().key(APP_KEY).prompt(INSTRUCTION).name("记忆启用评估")
|
||||
.description("通过大模型对记忆做正确性评估以决定是否启用").enable(false).build());
|
||||
}
|
||||
|
||||
@Scheduled(fixedDelay = 60 * 1000)
|
||||
public void review() {
|
||||
try {
|
||||
@@ -58,16 +64,22 @@ public class MemoryReviewTask {
|
||||
|
||||
private void processMemory(ChatMemoryDO m) {
|
||||
Agent chatAgent = agentService.getAgent(m.getAgentId());
|
||||
if (Objects.isNull(chatAgent) || !chatAgent.enableMemoryReview()) {
|
||||
log.debug("Agent id {} not found or memory review disabled", m.getAgentId());
|
||||
if (Objects.isNull(chatAgent)) {
|
||||
log.warn("Agent id {} not found or memory review disabled", m.getAgentId());
|
||||
return;
|
||||
}
|
||||
String promptStr = createPromptString(m);
|
||||
|
||||
ChatApp chatApp = chatAgent.getChatAppConfig().get(APP_KEY);
|
||||
if (!chatApp.isEnable()) {
|
||||
return;
|
||||
}
|
||||
|
||||
String promptStr = createPromptString(m, chatApp.getPrompt());
|
||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||
|
||||
keyPipelineLog.info("MemoryReviewTask reqPrompt:\n{}", promptStr);
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
|
||||
ModelConfigHelper.getChatModelConfig(chatAgent, ChatModelType.MEMORY_REVIEW));
|
||||
ChatLanguageModel chatLanguageModel =
|
||||
ModelProvider.getChatModel(ModelConfigHelper.getChatModelConfig(chatApp));
|
||||
if (Objects.nonNull(chatLanguageModel)) {
|
||||
String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();
|
||||
keyPipelineLog.info("MemoryReviewTask modelResp:\n{}", response);
|
||||
@@ -77,8 +89,8 @@ public class MemoryReviewTask {
|
||||
}
|
||||
}
|
||||
|
||||
private String createPromptString(ChatMemoryDO m) {
|
||||
return String.format(INSTRUCTION, m.getQuestion(), m.getDbSchema(), m.getSideInfo(),
|
||||
private String createPromptString(ChatMemoryDO m, String promptTemplate) {
|
||||
return String.format(promptTemplate, m.getQuestion(), m.getDbSchema(), m.getSideInfo(),
|
||||
m.getS2sql());
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package com.tencent.supersonic.chat.server.parser;
|
||||
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
@@ -11,10 +9,10 @@ import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
|
||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||
import com.tencent.supersonic.headless.api.pojo.SchemaElementMatch;
|
||||
@@ -47,7 +45,6 @@ import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULTI_TURN_ENABLE;
|
||||
import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_EXEMPLAR_RECALL_NUMBER;
|
||||
|
||||
@Slf4j
|
||||
@@ -55,6 +52,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
|
||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||
|
||||
public static final String APP_KEY_MULTI_TURN = "REWRITE_MULTI_TURN";
|
||||
private static final String REWRITE_MULTI_TURN_INSTRUCTION = ""
|
||||
+ "#Role: You are a data product manager experienced in data requirements."
|
||||
+ "#Task: Your will be provided with current and history questions asked by a user,"
|
||||
@@ -68,6 +66,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
+ "#History Mapped Schema: {{history_schema}}" + "#History SQL: {{history_sql}}"
|
||||
+ "#Rewritten Question: ";
|
||||
|
||||
public static final String APP_KEY_ERROR_MESSAGE = "REWRITE_ERROR_MESSAGE";
|
||||
private static final String REWRITE_ERROR_MESSAGE_INSTRUCTION = ""
|
||||
+ "#Role: You are a data business partner who closely interacts with business people.\n"
|
||||
+ "#Task: Your will be provided with user input, system output and some examples, "
|
||||
@@ -77,6 +76,16 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
+ "#Input: {{user_question}}\n" + "#Output: {{system_message}}\n"
|
||||
+ "#Examples: {{examples}}\n" + "#Response: ";
|
||||
|
||||
public NL2SQLParser() {
|
||||
ChatAppManager.register(
|
||||
ChatApp.builder().key(APP_KEY_MULTI_TURN).prompt(REWRITE_MULTI_TURN_INSTRUCTION)
|
||||
.name("多轮对话改写").description("通过大模型根据历史对话来改写本轮对话").enable(false).build());
|
||||
|
||||
ChatAppManager.register(ChatApp.builder().key(APP_KEY_ERROR_MESSAGE)
|
||||
.prompt(REWRITE_ERROR_MESSAGE_INSTRUCTION).name("异常提示改写")
|
||||
.description("通过大模型将异常信息改写为更友好和引导性的提示用语").enable(false).build());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void parse(ParseContext parseContext, ParseResp parseResp) {
|
||||
if (!parseContext.enableNL2SQL() || checkSkip(parseResp)) {
|
||||
@@ -97,10 +106,8 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
|
||||
} else {
|
||||
if (!parseContext.isDisableLLM()) {
|
||||
parseResp.setErrorMsg(rewriteErrorMessage(parseContext.getQueryText(),
|
||||
text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars(),
|
||||
parseContext.getAgent().getExamples(), ModelConfigHelper.getChatModelConfig(
|
||||
parseContext.getAgent(), ChatModelType.RESPONSE_GENERATE)));
|
||||
parseResp.setErrorMsg(rewriteErrorMessage(parseContext,
|
||||
text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars()));
|
||||
}
|
||||
}
|
||||
parseResp.setState(text2SqlParseResp.getState());
|
||||
@@ -162,22 +169,11 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
}
|
||||
|
||||
private void processMultiTurn(ParseContext parseContext) {
|
||||
Agent agent = parseContext.getAgent();
|
||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||
MultiTurnConfig agentMultiTurnConfig = agent.getMultiTurnConfig();
|
||||
Boolean globalMultiTurnConfig =
|
||||
Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
|
||||
|
||||
Boolean multiTurnConfig =
|
||||
agentMultiTurnConfig != null ? agentMultiTurnConfig.isEnableMultiTurn()
|
||||
: globalMultiTurnConfig;
|
||||
if (!Boolean.TRUE.equals(multiTurnConfig)) {
|
||||
ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_MULTI_TURN);
|
||||
if (!chatApp.isEnable()) {
|
||||
return;
|
||||
}
|
||||
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
|
||||
ModelConfigHelper.getChatModelConfig(agent, ChatModelType.MULTI_TURN_REWRITE));
|
||||
|
||||
// derive mapping result of current question and parsing result of last question.
|
||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||
QueryNLReq queryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
||||
@@ -203,9 +199,11 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
variables.put("history_schema", histMapStr);
|
||||
variables.put("history_sql", histSQL);
|
||||
|
||||
Prompt prompt = PromptTemplate.from(REWRITE_MULTI_TURN_INSTRUCTION).apply(variables);
|
||||
Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variables);
|
||||
keyPipelineLog.info("QueryRewrite reqPrompt:{}", prompt.text());
|
||||
|
||||
ChatLanguageModel chatLanguageModel =
|
||||
ModelProvider.getChatModel(ModelConfigHelper.getChatModelConfig(chatApp));
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
String rewrittenQuery = response.content().text();
|
||||
keyPipelineLog.info("QueryRewrite modelResp:{}", rewrittenQuery);
|
||||
@@ -217,24 +215,30 @@ public class NL2SQLParser implements ChatQueryParser {
|
||||
currentMapResult.getQueryText(), rewrittenQuery);
|
||||
}
|
||||
|
||||
private String rewriteErrorMessage(String userQuestion, String errMsg,
|
||||
List<Text2SQLExemplar> similarExemplars, List<String> agentExamples,
|
||||
ChatModelConfig modelConfig) {
|
||||
private String rewriteErrorMessage(ParseContext parseContext, String errMsg,
|
||||
List<Text2SQLExemplar> similarExemplars) {
|
||||
|
||||
ChatApp chatApp = parseContext.getAgent().getChatAppConfig().get(APP_KEY_ERROR_MESSAGE);
|
||||
if (!chatApp.isEnable()) {
|
||||
return errMsg;
|
||||
}
|
||||
|
||||
Map<String, Object> variables = new HashMap<>();
|
||||
variables.put("user_question", userQuestion);
|
||||
variables.put("user_question", parseContext.getQueryText());
|
||||
variables.put("system_message", errMsg);
|
||||
|
||||
StringBuilder exampleStr = new StringBuilder();
|
||||
similarExemplars.forEach(e -> exampleStr.append(
|
||||
String.format("<Question:{%s},Schema:{%s}> ", e.getQuestion(), e.getDbSchema())));
|
||||
agentExamples.forEach(e -> exampleStr.append(String.format("<Question:{%s}> ", e)));
|
||||
parseContext.getAgent().getExamples()
|
||||
.forEach(e -> exampleStr.append(String.format("<Question:{%s}> ", e)));
|
||||
variables.put("examples", exampleStr);
|
||||
|
||||
Prompt prompt = PromptTemplate.from(REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(variables);
|
||||
Prompt prompt = PromptTemplate.from(chatApp.getPrompt()).apply(variables);
|
||||
keyPipelineLog.info("ErrorRewrite reqPrompt:{}", prompt.text());
|
||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig);
|
||||
ChatLanguageModel chatLanguageModel =
|
||||
ModelProvider.getChatModel(ModelConfigHelper.getChatModelConfig(chatApp));
|
||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||
|
||||
String rewrittenMsg = response.content().text();
|
||||
keyPipelineLog.info("ErrorRewrite modelResp:{}", rewrittenMsg);
|
||||
|
||||
|
||||
@@ -33,15 +33,9 @@ public class AgentDO {
|
||||
|
||||
private Integer enableSearch;
|
||||
|
||||
private Integer enableMemoryReview;
|
||||
|
||||
private String toolConfig;
|
||||
|
||||
private String chatModelConfig;
|
||||
|
||||
private String multiTurnConfig;
|
||||
|
||||
private String visualConfig;
|
||||
|
||||
private String promptConfig;
|
||||
}
|
||||
|
||||
@@ -10,12 +10,15 @@ import com.tencent.supersonic.chat.server.config.ChatModelParameters;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatModel;
|
||||
import com.tencent.supersonic.chat.server.service.ChatModelService;
|
||||
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.Parameter;
|
||||
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||
import com.tencent.supersonic.common.util.ChatAppManager;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
@@ -51,11 +54,9 @@ public class ChatModelController {
|
||||
return chatModelService.getChatModels();
|
||||
}
|
||||
|
||||
@RequestMapping("/getModelTypeList")
|
||||
public List<ChatModelTypeResp> getModelTypeList() {
|
||||
return Arrays.stream(ChatModelType.values()).map(t -> ChatModelTypeResp.builder()
|
||||
.type(t.toString()).name(t.getName()).description(t.getDescription()).build())
|
||||
.collect(Collectors.toList());
|
||||
@RequestMapping("/getModelAppList")
|
||||
public List<ChatApp> getModelAppList() {
|
||||
return new ArrayList(ChatAppManager.getAllApps().values());
|
||||
}
|
||||
|
||||
@RequestMapping("/getModelParameters")
|
||||
|
||||
@@ -15,8 +15,8 @@ import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatModelService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
|
||||
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.chat.parser.llm.OnePassSCSqlGenStrategy;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
@@ -26,6 +26,7 @@ import org.springframework.stereotype.Service;
|
||||
import org.springframework.util.CollectionUtils;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
@@ -53,12 +54,6 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
||||
|
||||
@Override
|
||||
public Agent createAgent(Agent agent, User user) {
|
||||
if (Objects.isNull(agent.getPromptConfig())
|
||||
|| Objects.isNull(agent.getPromptConfig().getPromptTemplate())) {
|
||||
PromptConfig promptConfig = new PromptConfig();
|
||||
promptConfig.setPromptTemplate(OnePassSCSqlGenStrategy.INSTRUCTION.trim());
|
||||
agent.setPromptConfig(promptConfig);
|
||||
}
|
||||
agent.createdBy(user.getName());
|
||||
AgentDO agentDO = convert(agent);
|
||||
save(agentDO);
|
||||
@@ -69,12 +64,6 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
||||
|
||||
@Override
|
||||
public Agent updateAgent(Agent agent, User user) {
|
||||
if (Objects.isNull(agent.getPromptConfig())
|
||||
|| Objects.isNull(agent.getPromptConfig().getPromptTemplate())) {
|
||||
PromptConfig promptConfig = new PromptConfig();
|
||||
promptConfig.setPromptTemplate(OnePassSCSqlGenStrategy.INSTRUCTION.trim());
|
||||
agent.setPromptConfig(promptConfig);
|
||||
}
|
||||
agent.updatedBy(user.getName());
|
||||
updateById(convert(agent));
|
||||
executeAgentExamplesAsync(agent);
|
||||
@@ -105,10 +94,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
||||
}
|
||||
|
||||
private synchronized void doExecuteAgentExamples(Agent agent) {
|
||||
if (!agent.containsDatasetTool()
|
||||
|| !agent.enableMemoryReview()
|
||||
|| !ModelConfigHelper.testConnection(
|
||||
ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL))
|
||||
if (!agent.containsDatasetTool() || !agent.enableMemoryReview()
|
||||
|| CollectionUtils.isEmpty(agent.getExamples())) {
|
||||
return;
|
||||
}
|
||||
@@ -144,11 +130,8 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
||||
BeanUtils.copyProperties(agentDO, agent);
|
||||
agent.setToolConfig(agentDO.getToolConfig());
|
||||
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
|
||||
agent.setChatModelConfig(
|
||||
JsonUtil.toMap(agentDO.getChatModelConfig(), ChatModelType.class, Integer.class));
|
||||
agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class));
|
||||
agent.setMultiTurnConfig(
|
||||
JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
|
||||
agent.setChatAppConfig(
|
||||
JsonUtil.toMap(agentDO.getChatModelConfig(), String.class, ChatApp.class));
|
||||
agent.setVisualConfig(JsonUtil.toObject(agentDO.getVisualConfig(), VisualConfig.class));
|
||||
return agent;
|
||||
}
|
||||
@@ -158,10 +141,8 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
||||
BeanUtils.copyProperties(agent, agentDO);
|
||||
agentDO.setToolConfig(agent.getToolConfig());
|
||||
agentDO.setExamples(JsonUtil.toString(agent.getExamples()));
|
||||
agentDO.setChatModelConfig(JsonUtil.toString(agent.getChatModelConfig()));
|
||||
agentDO.setMultiTurnConfig(JsonUtil.toString(agent.getMultiTurnConfig()));
|
||||
agentDO.setChatModelConfig(JsonUtil.toString(agent.getChatAppConfig()));
|
||||
agentDO.setVisualConfig(JsonUtil.toString(agent.getVisualConfig()));
|
||||
agentDO.setPromptConfig(JsonUtil.toString(agent.getPromptConfig()));
|
||||
if (agentDO.getStatus() == null) {
|
||||
agentDO.setStatus(1);
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.server.util;
|
||||
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.service.ChatModelService;
|
||||
import com.tencent.supersonic.common.pojo.ChatApp;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||
@@ -27,13 +28,10 @@ public class ModelConfigHelper {
|
||||
}
|
||||
}
|
||||
|
||||
public static ChatModelConfig getChatModelConfig(Agent agent, ChatModelType modelType) {
|
||||
ChatModelConfig chatModelConfig = null;
|
||||
if (agent.getChatModelConfig().containsKey(modelType)) {
|
||||
Integer chatModelId = agent.getChatModelConfig().get(modelType);
|
||||
ChatModelService chatModelService = ContextUtils.getBean(ChatModelService.class);
|
||||
chatModelConfig = chatModelService.getChatModel(chatModelId).getConfig();
|
||||
}
|
||||
public static ChatModelConfig getChatModelConfig(ChatApp chatApp) {
|
||||
ChatModelService chatModelService = ContextUtils.getBean(ChatModelService.class);
|
||||
ChatModelConfig chatModelConfig =
|
||||
chatModelService.getChatModel(chatApp.getChatModelId()).getConfig();
|
||||
return chatModelConfig;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,10 +3,12 @@ package com.tencent.supersonic.chat.server.util;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||
import com.tencent.supersonic.chat.server.service.ChatModelService;
|
||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||
import com.tencent.supersonic.common.util.BeanMapper;
|
||||
import com.tencent.supersonic.common.util.ContextUtils;
|
||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||
import org.apache.commons.collections.MapUtils;
|
||||
|
||||
@@ -37,10 +39,7 @@ public class QueryReqConverter {
|
||||
&& MapUtils.isNotEmpty(queryNLReq.getMapInfo().getDataSetElementMatches())) {
|
||||
queryNLReq.setMapInfo(queryNLReq.getMapInfo());
|
||||
}
|
||||
ChatModelConfig chatModelConfig =
|
||||
ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL);
|
||||
queryNLReq.setModelConfig(chatModelConfig);
|
||||
queryNLReq.setCustomPrompt(agent.getPromptConfig().getPromptTemplate());
|
||||
queryNLReq.setChatAppConfig(parseContext.getAgent().getChatAppConfig());
|
||||
if (chatCtx != null) {
|
||||
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user