mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:51:00 +00:00
[feature][chat]Introduce dedicated LLM management.#1739
This commit is contained in:
@@ -5,17 +5,12 @@ import com.google.common.collect.Lists;
|
|||||||
import com.google.common.collect.Sets;
|
import com.google.common.collect.Sets;
|
||||||
import com.tencent.supersonic.common.config.PromptConfig;
|
import com.tencent.supersonic.common.config.PromptConfig;
|
||||||
import com.tencent.supersonic.common.config.VisualConfig;
|
import com.tencent.supersonic.common.config.VisualConfig;
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
|
||||||
import com.tencent.supersonic.common.pojo.RecordInfo;
|
import com.tencent.supersonic.common.pojo.RecordInfo;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
import org.springframework.util.CollectionUtils;
|
import org.springframework.util.CollectionUtils;
|
||||||
|
|
||||||
import java.util.Collection;
|
import java.util.*;
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
import java.util.Objects;
|
|
||||||
import java.util.Set;
|
|
||||||
import java.util.stream.Collectors;
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
@@ -29,10 +24,9 @@ public class Agent extends RecordInfo {
|
|||||||
|
|
||||||
/** 0 offline, 1 online */
|
/** 0 offline, 1 online */
|
||||||
private Integer status;
|
private Integer status;
|
||||||
|
|
||||||
private List<String> examples;
|
private List<String> examples;
|
||||||
private String agentConfig;
|
private String agentConfig;
|
||||||
private ChatModelConfig modelConfig;
|
private Map<ChatModelType, Integer> modelConfig = Collections.EMPTY_MAP;
|
||||||
private PromptConfig promptConfig;
|
private PromptConfig promptConfig;
|
||||||
private MultiTurnConfig multiTurnConfig;
|
private MultiTurnConfig multiTurnConfig;
|
||||||
private VisualConfig visualConfig;
|
private VisualConfig visualConfig;
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import com.tencent.supersonic.chat.server.parser.ParserConfig;
|
|||||||
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
import com.tencent.supersonic.chat.server.pojo.ExecuteContext;
|
||||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||||
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||||
|
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
|
||||||
import dev.langchain4j.data.message.AiMessage;
|
import dev.langchain4j.data.message.AiMessage;
|
||||||
@@ -44,8 +46,8 @@ public class PlainTextExecutor implements ChatQueryExecutor {
|
|||||||
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
AgentService agentService = ContextUtils.getBean(AgentService.class);
|
||||||
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
|
Agent chatAgent = agentService.getAgent(executeContext.getAgent().getId());
|
||||||
|
|
||||||
ChatLanguageModel chatLanguageModel =
|
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
|
||||||
ModelProvider.getChatModel(chatAgent.getModelConfig());
|
ModelConfigHelper.getChatModelConfig(chatAgent, ChatModelType.RESPONSE_GENERATE));
|
||||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||||
|
|
||||||
QueryResult result = new QueryResult();
|
QueryResult result = new QueryResult();
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import com.tencent.supersonic.chat.server.agent.Agent;
|
|||||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||||
|
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.model.input.Prompt;
|
import dev.langchain4j.model.input.Prompt;
|
||||||
import dev.langchain4j.model.input.PromptTemplate;
|
import dev.langchain4j.model.input.PromptTemplate;
|
||||||
@@ -63,8 +65,8 @@ public class MemoryReviewTask {
|
|||||||
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
|
||||||
|
|
||||||
keyPipelineLog.info("MemoryReviewTask reqPrompt:\n{}", promptStr);
|
keyPipelineLog.info("MemoryReviewTask reqPrompt:\n{}", promptStr);
|
||||||
ChatLanguageModel chatLanguageModel =
|
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
|
||||||
ModelProvider.getChatModel(chatAgent.getModelConfig());
|
ModelConfigHelper.getChatModelConfig(chatAgent, ChatModelType.MEMORY_REVIEW));
|
||||||
if (Objects.nonNull(chatLanguageModel)) {
|
if (Objects.nonNull(chatLanguageModel)) {
|
||||||
String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();
|
String response = chatLanguageModel.generate(prompt.toUserMessage()).content().text();
|
||||||
keyPipelineLog.info("MemoryReviewTask modelResp:\n{}", response);
|
keyPipelineLog.info("MemoryReviewTask modelResp:\n{}", response);
|
||||||
|
|||||||
@@ -1,16 +1,19 @@
|
|||||||
package com.tencent.supersonic.chat.server.parser;
|
package com.tencent.supersonic.chat.server.parser;
|
||||||
|
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResp;
|
||||||
|
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||||
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
import com.tencent.supersonic.chat.server.plugin.PluginQueryManager;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
import com.tencent.supersonic.chat.server.service.ChatContextService;
|
||||||
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||||
|
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
|
||||||
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
|
||||||
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
import com.tencent.supersonic.common.config.EmbeddingConfig;
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
import com.tencent.supersonic.common.pojo.Text2SQLExemplar;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||||
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
import com.tencent.supersonic.common.service.impl.ExemplarServiceImpl;
|
||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
@@ -52,7 +55,7 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
|
|
||||||
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
|
||||||
|
|
||||||
private static final String REWRITE_USER_QUESTION_INSTRUCTION = ""
|
private static final String REWRITE_MULTI_TURN_INSTRUCTION = ""
|
||||||
+ "#Role: You are a data product manager experienced in data requirements."
|
+ "#Role: You are a data product manager experienced in data requirements."
|
||||||
+ "#Task: Your will be provided with current and history questions asked by a user,"
|
+ "#Task: Your will be provided with current and history questions asked by a user,"
|
||||||
+ "along with their mapped schema elements(metric, dimension and value),"
|
+ "along with their mapped schema elements(metric, dimension and value),"
|
||||||
@@ -96,8 +99,8 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
if (parseContext.enbaleLLM()) {
|
if (parseContext.enbaleLLM()) {
|
||||||
parseResp.setErrorMsg(rewriteErrorMessage(parseContext.getQueryText(),
|
parseResp.setErrorMsg(rewriteErrorMessage(parseContext.getQueryText(),
|
||||||
text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars(),
|
text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars(),
|
||||||
parseContext.getAgent().getExamples(),
|
parseContext.getAgent().getExamples(), ModelConfigHelper.getChatModelConfig(
|
||||||
parseContext.getAgent().getModelConfig()));
|
parseContext.getAgent(), ChatModelType.RESPONSE_GENERATE)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
parseResp.setState(text2SqlParseResp.getState());
|
parseResp.setState(text2SqlParseResp.getState());
|
||||||
@@ -158,8 +161,9 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private void processMultiTurn(ParseContext parseContext) {
|
private void processMultiTurn(ParseContext parseContext) {
|
||||||
|
Agent agent = parseContext.getAgent();
|
||||||
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
|
||||||
MultiTurnConfig agentMultiTurnConfig = parseContext.getAgent().getMultiTurnConfig();
|
MultiTurnConfig agentMultiTurnConfig = agent.getMultiTurnConfig();
|
||||||
Boolean globalMultiTurnConfig =
|
Boolean globalMultiTurnConfig =
|
||||||
Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
|
Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
|
||||||
|
|
||||||
@@ -170,8 +174,8 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
ChatLanguageModel chatLanguageModel =
|
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(
|
||||||
ModelProvider.getChatModel(parseContext.getAgent().getModelConfig());
|
ModelConfigHelper.getChatModelConfig(agent, ChatModelType.MULTI_TURN_REWRITE));
|
||||||
|
|
||||||
// derive mapping result of current question and parsing result of last question.
|
// derive mapping result of current question and parsing result of last question.
|
||||||
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
ChatLayerService chatLayerService = ContextUtils.getBean(ChatLayerService.class);
|
||||||
@@ -198,12 +202,12 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
variables.put("history_schema", histMapStr);
|
variables.put("history_schema", histMapStr);
|
||||||
variables.put("history_sql", histSQL);
|
variables.put("history_sql", histSQL);
|
||||||
|
|
||||||
Prompt prompt = PromptTemplate.from(REWRITE_USER_QUESTION_INSTRUCTION).apply(variables);
|
Prompt prompt = PromptTemplate.from(REWRITE_MULTI_TURN_INSTRUCTION).apply(variables);
|
||||||
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text());
|
keyPipelineLog.info("QueryRewrite reqPrompt:{}", prompt.text());
|
||||||
|
|
||||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||||
String rewrittenQuery = response.content().text();
|
String rewrittenQuery = response.content().text();
|
||||||
keyPipelineLog.info("NL2SQLParser modelResp:{}", rewrittenQuery);
|
keyPipelineLog.info("QueryRewrite modelResp:{}", rewrittenQuery);
|
||||||
parseContext.setQueryText(rewrittenQuery);
|
parseContext.setQueryText(rewrittenQuery);
|
||||||
QueryNLReq rewrittenQueryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
QueryNLReq rewrittenQueryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
|
||||||
MapResp rewrittenQueryMapResult = chatLayerService.map(rewrittenQueryNLReq);
|
MapResp rewrittenQueryMapResult = chatLayerService.map(rewrittenQueryNLReq);
|
||||||
@@ -226,12 +230,12 @@ public class NL2SQLParser implements ChatQueryParser {
|
|||||||
variables.put("examples", exampleStr);
|
variables.put("examples", exampleStr);
|
||||||
|
|
||||||
Prompt prompt = PromptTemplate.from(REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(variables);
|
Prompt prompt = PromptTemplate.from(REWRITE_ERROR_MESSAGE_INSTRUCTION).apply(variables);
|
||||||
keyPipelineLog.info("NL2SQLParser reqPrompt:{}", prompt.text());
|
keyPipelineLog.info("ErrorRewrite reqPrompt:{}", prompt.text());
|
||||||
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig);
|
ChatLanguageModel chatLanguageModel = ModelProvider.getChatModel(modelConfig);
|
||||||
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());
|
||||||
|
|
||||||
String rewrittenMsg = response.content().text();
|
String rewrittenMsg = response.content().text();
|
||||||
keyPipelineLog.info("NL2SQLParser modelResp:{}", rewrittenMsg);
|
keyPipelineLog.info("ErrorRewrite modelResp:{}", rewrittenMsg);
|
||||||
|
|
||||||
return rewrittenMsg;
|
return rewrittenMsg;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,7 +45,9 @@ public class AgentDO {
|
|||||||
private Integer enableSearch;
|
private Integer enableSearch;
|
||||||
|
|
||||||
private Integer enableMemoryReview;
|
private Integer enableMemoryReview;
|
||||||
|
|
||||||
private String modelConfig;
|
private String modelConfig;
|
||||||
|
|
||||||
private String multiTurnConfig;
|
private String multiTurnConfig;
|
||||||
|
|
||||||
private String visualConfig;
|
private String visualConfig;
|
||||||
|
|||||||
@@ -0,0 +1,33 @@
|
|||||||
|
package com.tencent.supersonic.chat.server.persistence.dataobject;
|
||||||
|
|
||||||
|
import com.baomidou.mybatisplus.annotation.IdType;
|
||||||
|
import com.baomidou.mybatisplus.annotation.TableId;
|
||||||
|
import com.baomidou.mybatisplus.annotation.TableName;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.Date;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
@TableName("s2_chat_model")
|
||||||
|
public class ChatModelDO {
|
||||||
|
@TableId(type = IdType.AUTO)
|
||||||
|
private Integer id;
|
||||||
|
|
||||||
|
private String name;
|
||||||
|
|
||||||
|
private String description;
|
||||||
|
|
||||||
|
private String config;
|
||||||
|
|
||||||
|
private Date createdAt;
|
||||||
|
|
||||||
|
private String createdBy;
|
||||||
|
|
||||||
|
private Date updatedAt;
|
||||||
|
|
||||||
|
private String updatedBy;
|
||||||
|
|
||||||
|
private String admin;
|
||||||
|
|
||||||
|
private String viewer;
|
||||||
|
}
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
package com.tencent.supersonic.chat.server.persistence.mapper;
|
||||||
|
|
||||||
|
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
|
||||||
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatModelDO;
|
||||||
|
import org.apache.ibatis.annotations.Mapper;
|
||||||
|
|
||||||
|
@Mapper
|
||||||
|
public interface ChatModelMapper extends BaseMapper<ChatModelDO> {
|
||||||
|
}
|
||||||
@@ -0,0 +1,29 @@
|
|||||||
|
package com.tencent.supersonic.chat.server.pojo;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.Date;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class ChatModel {
|
||||||
|
private Integer id;
|
||||||
|
|
||||||
|
private String name;
|
||||||
|
|
||||||
|
private String description;
|
||||||
|
|
||||||
|
private ChatModelConfig Config;
|
||||||
|
|
||||||
|
private Date createdAt;
|
||||||
|
|
||||||
|
private String createdBy;
|
||||||
|
|
||||||
|
private Date updatedAt;
|
||||||
|
|
||||||
|
private String updatedBy;
|
||||||
|
|
||||||
|
private String admin;
|
||||||
|
|
||||||
|
private String viewer;
|
||||||
|
}
|
||||||
@@ -8,7 +8,7 @@ import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
|||||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||||
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.web.bind.annotation.DeleteMapping;
|
import org.springframework.web.bind.annotation.DeleteMapping;
|
||||||
@@ -61,6 +61,6 @@ public class AgentController {
|
|||||||
|
|
||||||
@PostMapping("/testLLMConn")
|
@PostMapping("/testLLMConn")
|
||||||
public boolean testLLMConn(@RequestBody ChatModelConfig modelConfig) {
|
public boolean testLLMConn(@RequestBody ChatModelConfig modelConfig) {
|
||||||
return LLMConnHelper.testConnection(modelConfig);
|
return ModelConfigHelper.testConnection(modelConfig);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,52 @@
|
|||||||
|
package com.tencent.supersonic.chat.server.rest;
|
||||||
|
|
||||||
|
import javax.servlet.http.HttpServletRequest;
|
||||||
|
import javax.servlet.http.HttpServletResponse;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
|
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
|
||||||
|
import com.tencent.supersonic.chat.server.pojo.ChatModel;
|
||||||
|
import com.tencent.supersonic.chat.server.service.ChatModelService;
|
||||||
|
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
|
||||||
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.web.bind.annotation.*;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
@RestController
|
||||||
|
@RequestMapping({"/api/chat/model", "/openapi/chat/model"})
|
||||||
|
public class ChatModelController {
|
||||||
|
@Autowired
|
||||||
|
private ChatModelService chatModelService;
|
||||||
|
|
||||||
|
@PostMapping
|
||||||
|
public ChatModel createModel(@RequestBody ChatModel model,
|
||||||
|
HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
|
||||||
|
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||||
|
return chatModelService.createChatModel(model, user);
|
||||||
|
}
|
||||||
|
|
||||||
|
@PutMapping
|
||||||
|
public ChatModel updateModel(@RequestBody ChatModel model,
|
||||||
|
HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
|
||||||
|
User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
|
||||||
|
return chatModelService.updateChatModel(model, user);
|
||||||
|
}
|
||||||
|
|
||||||
|
@DeleteMapping("/{id}")
|
||||||
|
public boolean deleteModel(@PathVariable("id") Integer id) {
|
||||||
|
chatModelService.deleteChatModel(id);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
@RequestMapping("/getModelList")
|
||||||
|
public List<ChatModel> getModelList() {
|
||||||
|
return chatModelService.getChatModels();
|
||||||
|
}
|
||||||
|
|
||||||
|
@PostMapping("/testConnection")
|
||||||
|
public boolean testConnection(@RequestBody ChatModelConfig modelConfig) {
|
||||||
|
return ModelConfigHelper.testConnection(modelConfig);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
package com.tencent.supersonic.chat.server.service;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
|
import com.tencent.supersonic.chat.server.pojo.ChatModel;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public interface ChatModelService {
|
||||||
|
List<ChatModel> getChatModels();
|
||||||
|
|
||||||
|
ChatModel getChatModel(Integer id);
|
||||||
|
|
||||||
|
ChatModel createChatModel(ChatModel chatModel, User user);
|
||||||
|
|
||||||
|
ChatModel updateChatModel(ChatModel chatModel, User user);
|
||||||
|
|
||||||
|
void deleteChatModel(Integer id);
|
||||||
|
}
|
||||||
@@ -10,12 +10,13 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
|
|||||||
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
|
||||||
import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper;
|
import com.tencent.supersonic.chat.server.persistence.mapper.AgentDOMapper;
|
||||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||||
|
import com.tencent.supersonic.chat.server.service.ChatModelService;
|
||||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||||
import com.tencent.supersonic.chat.server.service.MemoryService;
|
import com.tencent.supersonic.chat.server.service.MemoryService;
|
||||||
import com.tencent.supersonic.chat.server.util.LLMConnHelper;
|
import com.tencent.supersonic.chat.server.util.ModelConfigHelper;
|
||||||
import com.tencent.supersonic.common.config.PromptConfig;
|
import com.tencent.supersonic.common.config.PromptConfig;
|
||||||
import com.tencent.supersonic.common.config.VisualConfig;
|
import com.tencent.supersonic.common.config.VisualConfig;
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||||
import com.tencent.supersonic.common.util.JsonUtil;
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
import com.tencent.supersonic.headless.chat.parser.llm.OnePassSCSqlGenStrategy;
|
import com.tencent.supersonic.headless.chat.parser.llm.OnePassSCSqlGenStrategy;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
@@ -40,6 +41,9 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
|||||||
@Autowired
|
@Autowired
|
||||||
private ChatQueryService chatQueryService;
|
private ChatQueryService chatQueryService;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private ChatModelService chatModelService;
|
||||||
|
|
||||||
private ExecutorService executorService = Executors.newFixedThreadPool(1);
|
private ExecutorService executorService = Executors.newFixedThreadPool(1);
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -101,7 +105,9 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
|||||||
}
|
}
|
||||||
|
|
||||||
private synchronized void doExecuteAgentExamples(Agent agent) {
|
private synchronized void doExecuteAgentExamples(Agent agent) {
|
||||||
if (!agent.containsLLMTool() || !LLMConnHelper.testConnection(agent.getModelConfig())
|
if (!agent.containsLLMTool()
|
||||||
|
|| !ModelConfigHelper.testConnection(
|
||||||
|
ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL))
|
||||||
|| CollectionUtils.isEmpty(agent.getExamples())) {
|
|| CollectionUtils.isEmpty(agent.getExamples())) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -136,7 +142,8 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
|
|||||||
BeanUtils.copyProperties(agentDO, agent);
|
BeanUtils.copyProperties(agentDO, agent);
|
||||||
agent.setAgentConfig(agentDO.getConfig());
|
agent.setAgentConfig(agentDO.getConfig());
|
||||||
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
|
agent.setExamples(JsonUtil.toList(agentDO.getExamples(), String.class));
|
||||||
agent.setModelConfig(JsonUtil.toObject(agentDO.getModelConfig(), ChatModelConfig.class));
|
agent.setModelConfig(
|
||||||
|
JsonUtil.toMap(agentDO.getModelConfig(), ChatModelType.class, Integer.class));
|
||||||
agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class));
|
agent.setPromptConfig(JsonUtil.toObject(agentDO.getPromptConfig(), PromptConfig.class));
|
||||||
agent.setMultiTurnConfig(
|
agent.setMultiTurnConfig(
|
||||||
JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
|
JsonUtil.toObject(agentDO.getMultiTurnConfig(), MultiTurnConfig.class));
|
||||||
|
|||||||
@@ -0,0 +1,89 @@
|
|||||||
|
package com.tencent.supersonic.chat.server.service.impl;
|
||||||
|
|
||||||
|
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
|
||||||
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
|
import com.tencent.supersonic.chat.server.persistence.dataobject.ChatModelDO;
|
||||||
|
import com.tencent.supersonic.chat.server.persistence.mapper.ChatModelMapper;
|
||||||
|
import com.tencent.supersonic.chat.server.pojo.ChatModel;
|
||||||
|
import com.tencent.supersonic.chat.server.service.ChatModelService;
|
||||||
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
|
import com.tencent.supersonic.common.util.JsonUtil;
|
||||||
|
import com.tencent.supersonic.common.util.StringUtil;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.apache.commons.lang.StringUtils;
|
||||||
|
import org.springframework.beans.BeanUtils;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import java.util.Date;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
@Service
|
||||||
|
public class ChatModelServiceImpl extends ServiceImpl<ChatModelMapper, ChatModelDO>
|
||||||
|
implements ChatModelService {
|
||||||
|
@Override
|
||||||
|
public List<ChatModel> getChatModels() {
|
||||||
|
return list().stream().map(this::convert).collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ChatModel getChatModel(Integer id) {
|
||||||
|
if (id == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return convert(getById(id));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ChatModel createChatModel(ChatModel chatModel, User user) {
|
||||||
|
ChatModelDO chatModelDO = convert(chatModel);
|
||||||
|
chatModelDO.setCreatedBy(user.getName());
|
||||||
|
chatModelDO.setCreatedAt(new Date());
|
||||||
|
chatModelDO.setUpdatedBy(user.getName());
|
||||||
|
chatModelDO.setUpdatedAt(new Date());
|
||||||
|
if (StringUtils.isBlank(chatModel.getAdmin())) {
|
||||||
|
chatModelDO.setAdmin(user.getName());
|
||||||
|
}
|
||||||
|
save(chatModelDO);
|
||||||
|
return chatModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ChatModel updateChatModel(ChatModel chatModel, User user) {
|
||||||
|
ChatModelDO chatModelDO = convert(chatModel);
|
||||||
|
chatModelDO.setUpdatedBy(user.getName());
|
||||||
|
chatModelDO.setUpdatedAt(new Date());
|
||||||
|
if (StringUtils.isBlank(chatModel.getAdmin())) {
|
||||||
|
chatModel.setAdmin(user.getName());
|
||||||
|
}
|
||||||
|
updateById(chatModelDO);
|
||||||
|
return chatModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void deleteChatModel(Integer id) {
|
||||||
|
removeById(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
private ChatModel convert(ChatModelDO chatModelDO) {
|
||||||
|
if (chatModelDO == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
ChatModel chatModel = new ChatModel();
|
||||||
|
BeanUtils.copyProperties(chatModelDO, chatModel);
|
||||||
|
chatModel.setConfig(JsonUtil.toObject(chatModelDO.getConfig(), ChatModelConfig.class));
|
||||||
|
return chatModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
private ChatModelDO convert(ChatModel chatModel) {
|
||||||
|
if (chatModel == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
ChatModelDO chatModelDO = new ChatModelDO();
|
||||||
|
BeanUtils.copyProperties(chatModel, chatModelDO);
|
||||||
|
chatModelDO.setConfig(JsonUtil.toString(chatModel.getConfig()));
|
||||||
|
return chatModelDO;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,14 +1,18 @@
|
|||||||
package com.tencent.supersonic.chat.server.util;
|
package com.tencent.supersonic.chat.server.util;
|
||||||
|
|
||||||
|
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||||
|
import com.tencent.supersonic.chat.server.service.ChatModelService;
|
||||||
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||||
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
|
||||||
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import dev.langchain4j.model.chat.ChatLanguageModel;
|
import dev.langchain4j.model.chat.ChatLanguageModel;
|
||||||
import dev.langchain4j.provider.ModelProvider;
|
import dev.langchain4j.provider.ModelProvider;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class LLMConnHelper {
|
public class ModelConfigHelper {
|
||||||
public static boolean testConnection(ChatModelConfig modelConfig) {
|
public static boolean testConnection(ChatModelConfig modelConfig) {
|
||||||
try {
|
try {
|
||||||
if (modelConfig == null || StringUtils.isBlank(modelConfig.getBaseUrl())) {
|
if (modelConfig == null || StringUtils.isBlank(modelConfig.getBaseUrl())) {
|
||||||
@@ -22,4 +26,14 @@ public class LLMConnHelper {
|
|||||||
throw new InvalidArgumentException(e.getMessage());
|
throw new InvalidArgumentException(e.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static ChatModelConfig getChatModelConfig(Agent agent, ChatModelType modelType) {
|
||||||
|
ChatModelConfig chatModelConfig = null;
|
||||||
|
if (agent.getModelConfig().containsKey(modelType)) {
|
||||||
|
Integer chatModelId = agent.getModelConfig().get(modelType);
|
||||||
|
ChatModelService chatModelService = ContextUtils.getBean(ChatModelService.class);
|
||||||
|
chatModelConfig = chatModelService.getChatModel(chatModelId).getConfig();
|
||||||
|
}
|
||||||
|
return chatModelConfig;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@@ -3,6 +3,8 @@ package com.tencent.supersonic.chat.server.util;
|
|||||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
import com.tencent.supersonic.chat.server.pojo.ChatContext;
|
||||||
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
import com.tencent.supersonic.chat.server.pojo.ParseContext;
|
||||||
|
import com.tencent.supersonic.common.pojo.ChatModelConfig;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||||
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
import com.tencent.supersonic.common.pojo.enums.Text2SQLType;
|
||||||
import com.tencent.supersonic.common.util.BeanMapper;
|
import com.tencent.supersonic.common.util.BeanMapper;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
import com.tencent.supersonic.headless.api.pojo.request.QueryNLReq;
|
||||||
@@ -24,9 +26,11 @@ public class QueryReqConverter {
|
|||||||
return queryNLReq;
|
return queryNLReq;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ChatModelConfig chatModelConfig =
|
||||||
|
ModelConfigHelper.getChatModelConfig(agent, ChatModelType.TEXT_TO_SQL);
|
||||||
boolean hasLLMTool = agent.containsLLMTool();
|
boolean hasLLMTool = agent.containsLLMTool();
|
||||||
boolean hasRuleTool = agent.containsRuleTool();
|
boolean hasRuleTool = agent.containsRuleTool();
|
||||||
boolean hasLLMConfig = Objects.nonNull(agent.getModelConfig());
|
boolean hasLLMConfig = chatModelConfig != null;
|
||||||
|
|
||||||
if (parseContext.isDisableLLM()) {
|
if (parseContext.isDisableLLM()) {
|
||||||
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
|
queryNLReq.setText2SQLType(Text2SQLType.ONLY_RULE);
|
||||||
@@ -45,7 +49,7 @@ public class QueryReqConverter {
|
|||||||
&& MapUtils.isNotEmpty(queryNLReq.getMapInfo().getDataSetElementMatches())) {
|
&& MapUtils.isNotEmpty(queryNLReq.getMapInfo().getDataSetElementMatches())) {
|
||||||
queryNLReq.setMapInfo(queryNLReq.getMapInfo());
|
queryNLReq.setMapInfo(queryNLReq.getMapInfo());
|
||||||
}
|
}
|
||||||
queryNLReq.setModelConfig(agent.getModelConfig());
|
queryNLReq.setModelConfig(chatModelConfig);
|
||||||
queryNLReq.setPromptConfig(agent.getPromptConfig());
|
queryNLReq.setPromptConfig(agent.getPromptConfig());
|
||||||
if (chatCtx != null) {
|
if (chatCtx != null) {
|
||||||
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
|
queryNLReq.setContextParseInfo(chatCtx.getParseInfo());
|
||||||
|
|||||||
@@ -0,0 +1,14 @@
|
|||||||
|
package com.tencent.supersonic.common.pojo.enums;
|
||||||
|
|
||||||
|
public enum ChatModelType {
|
||||||
|
TEXT_TO_SQL("Convert text query to SQL statement"), MULTI_TURN_REWRITE(
|
||||||
|
"Rewrite text query for multi-turn conversation"), MEMORY_REVIEW(
|
||||||
|
"Review memory in order to add few-shot examples"), RESPONSE_GENERATE(
|
||||||
|
"Generate readable response to the end user");
|
||||||
|
|
||||||
|
private String description;
|
||||||
|
|
||||||
|
ChatModelType(String description) {
|
||||||
|
this.description = description;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -155,13 +155,8 @@ public class S2VisitsDemo extends S2BaseDemo {
|
|||||||
agent.setDescription("帮助您用自然语言查询指标,支持时间限定、条件筛选、下钻维度以及聚合统计");
|
agent.setDescription("帮助您用自然语言查询指标,支持时间限定、条件筛选、下钻维度以及聚合统计");
|
||||||
agent.setStatus(1);
|
agent.setStatus(1);
|
||||||
agent.setEnableSearch(1);
|
agent.setEnableSearch(1);
|
||||||
agent.setExamples(Lists.newArrayList(
|
agent.setExamples(Lists.newArrayList("近15天超音数访问次数汇总", "按部门统计超音数的访问人数", "对比alice和lucy的停留时长",
|
||||||
"近15天超音数访问次数汇总",
|
"过去30天访问次数最高的部门top3", "近1个月总访问次数超过100次的部门有几个", "过去半个月每个核心用户的总停留时长"));
|
||||||
"按部门统计超音数的访问人数",
|
|
||||||
"对比alice和lucy的停留时长",
|
|
||||||
"过去30天访问次数最高的部门top3",
|
|
||||||
"近1个月总访问次数超过100次的部门有几个",
|
|
||||||
"过去半个月每个核心用户的总停留时长"));
|
|
||||||
AgentConfig agentConfig = new AgentConfig();
|
AgentConfig agentConfig = new AgentConfig();
|
||||||
RuleParserTool ruleQueryTool = new RuleParserTool();
|
RuleParserTool ruleQueryTool = new RuleParserTool();
|
||||||
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
|
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
|
||||||
|
|||||||
@@ -103,6 +103,21 @@ CREATE TABLE IF NOT EXISTS `s2_chat_memory` (
|
|||||||
) ;
|
) ;
|
||||||
COMMENT ON TABLE s2_chat_memory IS 'chat memory table ';
|
COMMENT ON TABLE s2_chat_memory IS 'chat memory table ';
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS `s2_chat_model`
|
||||||
|
(
|
||||||
|
id int AUTO_INCREMENT,
|
||||||
|
name varchar(100) null,
|
||||||
|
description varchar(500) null,
|
||||||
|
`config` varchar(500) NOT NULL ,
|
||||||
|
`created_at` TIMESTAMP NOT NULL ,
|
||||||
|
`created_by` varchar(100) NOT NULL ,
|
||||||
|
`updated_at` TIMESTAMP NOT NULL ,
|
||||||
|
`updated_by` varchar(100) NOT NULL,
|
||||||
|
`admin` varchar(500) NOT NULL,
|
||||||
|
`viewer` varchar(500) DEFAULT NULL,
|
||||||
|
PRIMARY KEY (`id`)
|
||||||
|
); COMMENT ON TABLE s2_chat_model IS 'chat model table';
|
||||||
|
|
||||||
create table IF NOT EXISTS s2_user
|
create table IF NOT EXISTS s2_user
|
||||||
(
|
(
|
||||||
id INT AUTO_INCREMENT,
|
id INT AUTO_INCREMENT,
|
||||||
@@ -388,7 +403,6 @@ CREATE TABLE IF NOT EXISTS s2_agent
|
|||||||
PRIMARY KEY (`id`)
|
PRIMARY KEY (`id`)
|
||||||
); COMMENT ON TABLE s2_agent IS 'agent information table';
|
); COMMENT ON TABLE s2_agent IS 'agent information table';
|
||||||
|
|
||||||
|
|
||||||
-------demo for semantic and chat
|
-------demo for semantic and chat
|
||||||
CREATE TABLE IF NOT EXISTS `s2_user_department` (
|
CREATE TABLE IF NOT EXISTS `s2_user_department` (
|
||||||
`user_name` varchar(200) NOT NULL,
|
`user_name` varchar(200) NOT NULL,
|
||||||
|
|||||||
@@ -210,6 +210,20 @@ CREATE TABLE IF NOT EXISTS `s2_chat_statistics` (
|
|||||||
KEY `commonIndex` (`question_id`)
|
KEY `commonIndex` (`question_id`)
|
||||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS `s2_chat_model` (
|
||||||
|
`id` bigint(20) NOT NULL AUTO_INCREMENT,
|
||||||
|
`name` varchar(255) NOT NULL COMMENT '名称',
|
||||||
|
`description` varchar(500) DEFAULT NULL COMMENT '描述',
|
||||||
|
`config` text NOT NULL COMMENT '配置信息',
|
||||||
|
`created_at` datetime NOT NULL COMMENT '创建时间',
|
||||||
|
`created_by` varchar(100) NOT NULL COMMENT '创建人',
|
||||||
|
`updated_at` datetime NOT NULL COMMENT '更新时间',
|
||||||
|
`updated_by` varchar(100) NOT NULL COMMENT '更新人',
|
||||||
|
`admin` varchar(500) DEFAULT NULL,
|
||||||
|
`viewer` varchar(500) DEFAULT NULL,
|
||||||
|
PRIMARY KEY (`id`)
|
||||||
|
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='对话大模型实例表';
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS `s2_database` (
|
CREATE TABLE IF NOT EXISTS `s2_database` (
|
||||||
`id` bigint(20) NOT NULL AUTO_INCREMENT,
|
`id` bigint(20) NOT NULL AUTO_INCREMENT,
|
||||||
`name` varchar(255) NOT NULL COMMENT '名称',
|
`name` varchar(255) NOT NULL COMMENT '名称',
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
|
|||||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||||
|
import com.tencent.supersonic.chat.server.service.ChatModelService;
|
||||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||||
import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
|
import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
|
||||||
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
|
||||||
@@ -33,6 +34,8 @@ public class BaseTest extends BaseApplication {
|
|||||||
protected ChatQueryService chatQueryService;
|
protected ChatQueryService chatQueryService;
|
||||||
@Autowired
|
@Autowired
|
||||||
protected AgentService agentService;
|
protected AgentService agentService;
|
||||||
|
@Autowired
|
||||||
|
protected ChatModelService chatModelService;
|
||||||
|
|
||||||
protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId)
|
protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId)
|
||||||
throws Exception {
|
throws Exception {
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package com.tencent.supersonic.evaluation;
|
|||||||
|
|
||||||
import com.alibaba.fastjson.JSONObject;
|
import com.alibaba.fastjson.JSONObject;
|
||||||
import com.google.common.collect.Lists;
|
import com.google.common.collect.Lists;
|
||||||
|
import com.google.common.collect.Maps;
|
||||||
|
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||||
import com.tencent.supersonic.chat.BaseTest;
|
import com.tencent.supersonic.chat.BaseTest;
|
||||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||||
@@ -9,12 +11,14 @@ import com.tencent.supersonic.chat.server.agent.AgentConfig;
|
|||||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||||
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
|
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
|
||||||
|
import com.tencent.supersonic.chat.server.pojo.ChatModel;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||||
import com.tencent.supersonic.util.DataUtils;
|
import com.tencent.supersonic.util.DataUtils;
|
||||||
import com.tencent.supersonic.util.LLMConfigUtils;
|
import com.tencent.supersonic.util.LLMConfigUtils;
|
||||||
import org.junit.jupiter.api.*;
|
import org.junit.jupiter.api.*;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.stream.Collectors;
|
import java.util.Map;
|
||||||
|
|
||||||
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
||||||
@Disabled
|
@Disabled
|
||||||
@@ -133,7 +137,13 @@ public class Text2SQLEval extends BaseTest {
|
|||||||
AgentConfig agentConfig = new AgentConfig();
|
AgentConfig agentConfig = new AgentConfig();
|
||||||
agentConfig.getTools().add(getLLMQueryTool());
|
agentConfig.getTools().add(getLLMQueryTool());
|
||||||
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
|
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
|
||||||
agent.setModelConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3));
|
ChatModel chatModel = new ChatModel();
|
||||||
|
chatModel.setName("Text2SQL LLM");
|
||||||
|
chatModel.setConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3));
|
||||||
|
chatModel = chatModelService.createChatModel(chatModel, User.getFakeUser());
|
||||||
|
Map<ChatModelType, Integer> chatModelConfig = Maps.newHashMap();
|
||||||
|
chatModelConfig.put(ChatModelType.TEXT_TO_SQL, chatModel.getId());
|
||||||
|
agent.setModelConfig(chatModelConfig);
|
||||||
MultiTurnConfig multiTurnConfig = new MultiTurnConfig();
|
MultiTurnConfig multiTurnConfig = new MultiTurnConfig();
|
||||||
multiTurnConfig.setEnableMultiTurn(enableMultiturn);
|
multiTurnConfig.setEnableMultiTurn(enableMultiturn);
|
||||||
agent.setMultiTurnConfig(multiTurnConfig);
|
agent.setMultiTurnConfig(multiTurnConfig);
|
||||||
|
|||||||
@@ -103,6 +103,21 @@ CREATE TABLE IF NOT EXISTS `s2_chat_memory` (
|
|||||||
) ;
|
) ;
|
||||||
COMMENT ON TABLE s2_chat_memory IS 'chat memory table ';
|
COMMENT ON TABLE s2_chat_memory IS 'chat memory table ';
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS `s2_chat_model`
|
||||||
|
(
|
||||||
|
id int AUTO_INCREMENT,
|
||||||
|
name varchar(100) null,
|
||||||
|
description varchar(500) null,
|
||||||
|
`config` varchar(500) NOT NULL ,
|
||||||
|
`created_at` TIMESTAMP NOT NULL ,
|
||||||
|
`created_by` varchar(100) NOT NULL ,
|
||||||
|
`updated_at` TIMESTAMP NOT NULL ,
|
||||||
|
`updated_by` varchar(100) NOT NULL,
|
||||||
|
`admin` varchar(500) NOT NULL,
|
||||||
|
`viewer` varchar(500) DEFAULT NULL,
|
||||||
|
PRIMARY KEY (`id`)
|
||||||
|
); COMMENT ON TABLE s2_chat_model IS 'chat model table';
|
||||||
|
|
||||||
create table IF NOT EXISTS s2_user
|
create table IF NOT EXISTS s2_user
|
||||||
(
|
(
|
||||||
id INT AUTO_INCREMENT,
|
id INT AUTO_INCREMENT,
|
||||||
|
|||||||
Reference in New Issue
Block a user