[feature][chat]Refactor chat model config related codes.#1739

This commit is contained in:
jerryjzhang
2024-10-09 17:27:07 +08:00
parent 60b0a1a1a1
commit 248f4f83f6
53 changed files with 275 additions and 251 deletions

View File

@@ -224,7 +224,7 @@ public class CspiderDemo extends S2BaseDemo {
queryConfig.setDetailTypeDefaultConfig(detailTypeDefaultConfig);
queryConfig.setAggregateTypeDefaultConfig(aggregateTypeDefaultConfig);
dataSetReq.setQueryConfig(queryConfig);
dataSetService.save(dataSetReq, User.getFakeUser());
dataSetService.save(dataSetReq, User.getDefaultUser());
}
public void addModelRela_1(DomainResp s2Domain, ModelResp genreModelResp,
@@ -296,6 +296,6 @@ public class CspiderDemo extends S2BaseDemo {
private void batchPushlishMetric() {
List<Long> ids = Lists.newArrayList(1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L);
metricService.batchPublish(ids, User.getFakeUser());
metricService.batchPublish(ids, User.getDefaultUser());
}
}

View File

@@ -4,9 +4,9 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentConfig;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.LLMParserTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.common.pojo.JoinCondition;
import com.tencent.supersonic.common.pojo.ModelRela;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
@@ -274,7 +274,7 @@ public class DuSQLDemo extends S2BaseDemo {
aggregateTypeDefaultConfig.setTimeDefaultConfig(timeDefaultConfig);
queryConfig.setAggregateTypeDefaultConfig(aggregateTypeDefaultConfig);
dataSetReq.setQueryConfig(queryConfig);
dataSetService.save(dataSetReq, User.getFakeUser());
dataSetService.save(dataSetReq, User.getDefaultUser());
}
public void addModelRela_1() {
@@ -334,16 +334,16 @@ public class DuSQLDemo extends S2BaseDemo {
agent.setStatus(1);
agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList());
AgentConfig agentConfig = new AgentConfig();
ToolConfig toolConfig = new ToolConfig();
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(4L));
agentConfig.getTools().add(llmParserTool);
toolConfig.getTools().add(llmParserTool);
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
log.info("agent:{}", JsonUtil.toString(agent));
agentService.createAgent(agent, User.getFakeUser());
agentService.createAgent(agent, User.getDefaultUser());
}
}

View File

@@ -4,10 +4,10 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentConfig;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.LLMParserTool;
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.headless.api.pojo.*;
@@ -69,7 +69,7 @@ public class S2ArtistDemo extends S2BaseDemo {
tagObjectReq.setDomainId(singerDomain.getId());
tagObjectReq.setName("歌手");
tagObjectReq.setBizName("singer");
User user = User.getFakeUser();
User user = User.getDefaultUser();
return tagObjectService.create(tagObjectReq, user);
}
@@ -159,7 +159,7 @@ public class S2ArtistDemo extends S2BaseDemo {
queryConfig.setDetailTypeDefaultConfig(detailTypeDefaultConfig);
queryConfig.setAggregateTypeDefaultConfig(aggregateTypeDefaultConfig);
dataSetReq.setQueryConfig(queryConfig);
DataSetResp dataSetResp = dataSetService.save(dataSetReq, User.getFakeUser());
DataSetResp dataSetResp = dataSetService.save(dataSetReq, User.getDefaultUser());
return dataSetResp.getId();
}
@@ -170,21 +170,21 @@ public class S2ArtistDemo extends S2BaseDemo {
agent.setStatus(1);
agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList("国风流派歌手", "港台歌手", "周杰伦流派"));
AgentConfig agentConfig = new AgentConfig();
ToolConfig toolConfig = new ToolConfig();
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setId("0");
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setDataSetIds(Lists.newArrayList(dataSetId));
agentConfig.getTools().add(ruleQueryTool);
toolConfig.getTools().add(ruleQueryTool);
if (demoEnableLlm) {
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(dataSetId));
agentConfig.getTools().add(llmParserTool);
toolConfig.getTools().add(llmParserTool);
}
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
agentService.createAgent(agent, User.getFakeUser());
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
agentService.createAgent(agent, User.getDefaultUser());
}
}

View File

@@ -3,10 +3,8 @@ package com.tencent.supersonic.demo;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authorization.service.AuthService;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.chat.server.service.PluginService;
import com.tencent.supersonic.chat.server.pojo.ChatModel;
import com.tencent.supersonic.chat.server.service.*;
import com.tencent.supersonic.common.service.SystemConfigService;
import com.tencent.supersonic.common.util.AESEncryptionUtil;
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
@@ -33,6 +31,7 @@ import com.tencent.supersonic.headless.server.service.TagMetaService;
import com.tencent.supersonic.headless.server.service.TagObjectService;
import com.tencent.supersonic.headless.server.service.TermService;
import com.tencent.supersonic.headless.server.service.impl.DictWordService;
import dev.langchain4j.provider.ModelProvider;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
@@ -47,8 +46,9 @@ import java.util.stream.Collectors;
@Slf4j
public abstract class S2BaseDemo implements CommandLineRunner {
protected DatabaseResp demoDatabaseResp;
protected ChatModel chatModel;
protected User user = User.getFakeUser();
protected User user = User.getDefaultUser();
@Autowired
protected DatabaseService databaseService;
@Autowired
@@ -87,6 +87,8 @@ public abstract class S2BaseDemo implements CommandLineRunner {
protected CanvasService canvasService;
@Autowired
protected DictWordService dictWordService;
@Autowired
protected ChatModelService chatModelService;
@Value("${s2.demo.names:S2VisitsDemo}")
protected List<String> demoList;
@@ -96,6 +98,7 @@ public abstract class S2BaseDemo implements CommandLineRunner {
public void run(String... args) {
demoDatabaseResp = addDatabaseIfNotExist();
addChatModelIfNotExist();
if (demoList != null && demoList.contains(getClass().getSimpleName())) {
if (checkNeedToRun()) {
doRun();
@@ -108,7 +111,7 @@ public abstract class S2BaseDemo implements CommandLineRunner {
abstract boolean checkNeedToRun();
protected DatabaseResp addDatabaseIfNotExist() {
List<DatabaseResp> databaseList = databaseService.getDatabaseList(User.getFakeUser());
List<DatabaseResp> databaseList = databaseService.getDatabaseList(User.getDefaultUser());
if (!CollectionUtils.isEmpty(databaseList)) {
return databaseList.get(0);
}
@@ -130,6 +133,17 @@ public abstract class S2BaseDemo implements CommandLineRunner {
return databaseService.createOrUpdateDatabase(databaseReq, user);
}
protected void addChatModelIfNotExist() {
if (chatModelService.getChatModels().size() > 0) {
return;
}
chatModel = new ChatModel();
chatModel.setName("OpenAI模型DEMO");
chatModel.setDescription("由langchain4j社区提供仅用于体验单次请求最大token数1000");
chatModel.setConfig(ModelProvider.DEMO_CHAT_MODEL);
chatModel = chatModelService.createChatModel(chatModel, User.getDefaultUser());
}
protected MetricResp getMetric(String bizName, ModelResp model) {
return metricService.getMetric(model.getId(), bizName);
}
@@ -160,7 +174,7 @@ public abstract class S2BaseDemo implements CommandLineRunner {
TagReq tagReq = new TagReq();
tagReq.setTagDefineType(tagDefineType);
tagReq.setItemId(itemId);
tagMetaService.create(tagReq, User.getFakeUser());
tagMetaService.create(tagReq, User.getDefaultUser());
}
protected DimensionResp getDimension(String bizName, ModelResp model) {

View File

@@ -2,16 +2,17 @@ package com.tencent.supersonic.demo;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentConfig;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.LLMParserTool;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
@@ -19,12 +20,7 @@ import com.tencent.supersonic.chat.server.plugin.build.webpage.WebPageQuery;
import com.tencent.supersonic.chat.server.plugin.build.webservice.WebServiceQuery;
import com.tencent.supersonic.common.pojo.JoinCondition;
import com.tencent.supersonic.common.pojo.ModelRela;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.SensitiveLevelEnum;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.common.pojo.enums.*;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.DataSetDetail;
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
@@ -63,10 +59,7 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.*;
@Component
@Slf4j
@@ -146,7 +139,7 @@ public class S2VisitsDemo extends S2BaseDemo {
private void submitText(int chatId, int agentId, String queryText) {
chatQueryService.parseAndExecute(ChatParseReq.builder().chatId(chatId).agentId(agentId)
.queryText(queryText).user(User.getFakeUser()).disableLLM(true).build());
.queryText(queryText).user(User.getDefaultUser()).disableLLM(true).build());
}
private Integer addAgent(long dataSetId) {
@@ -157,23 +150,32 @@ public class S2VisitsDemo extends S2BaseDemo {
agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList("近15天超音数访问次数汇总", "按部门统计超音数的访问人数", "对比alice和lucy的停留时长",
"过去30天访问次数最高的部门top3", "近1个月总访问次数超过100次的部门有几个", "过去半个月每个核心用户的总停留时长"));
AgentConfig agentConfig = new AgentConfig();
// configure tools
ToolConfig toolConfig = new ToolConfig();
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setId("0");
ruleQueryTool.setDataSetIds(Lists.newArrayList(dataSetId));
agentConfig.getTools().add(ruleQueryTool);
toolConfig.getTools().add(ruleQueryTool);
if (demoEnableLlm) {
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(dataSetId));
agentConfig.getTools().add(llmParserTool);
toolConfig.getTools().add(llmParserTool);
}
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
// configure chat models
Map<ChatModelType, Integer> chatModelConfig = Maps.newHashMap();
chatModelConfig.put(ChatModelType.TEXT_TO_SQL, chatModel.getId());
chatModelConfig.put(ChatModelType.MEMORY_REVIEW, chatModel.getId());
chatModelConfig.put(ChatModelType.RESPONSE_GENERATE, chatModel.getId());
chatModelConfig.put(ChatModelType.MULTI_TURN_REWRITE, chatModel.getId());
agent.setChatModelConfig(chatModelConfig);
MultiTurnConfig multiTurnConfig = new MultiTurnConfig(true);
agent.setMultiTurnConfig(multiTurnConfig);
Agent agentCreated = agentService.createAgent(agent, User.getFakeUser());
Agent agentCreated = agentService.createAgent(agent, User.getDefaultUser());
return agentCreated.getId();
}
@@ -460,7 +462,7 @@ public class S2VisitsDemo extends S2BaseDemo {
dataSetDetail.setDataSetModelConfigs(dataSetModelConfigs);
dataSetReq.setDataSetDetail(dataSetDetail);
dataSetReq.setTypeEnum(TypeEnums.DATASET);
return dataSetService.save(dataSetReq, User.getFakeUser());
return dataSetService.save(dataSetReq, User.getDefaultUser());
}
public void addTerm(DomainResp s2Domain) {
@@ -469,7 +471,7 @@ public class S2VisitsDemo extends S2BaseDemo {
termReq.setDescription("指近10天");
termReq.setAlias(Lists.newArrayList("近一段时间"));
termReq.setDomainId(s2Domain.getId());
termService.saveOrUpdate(termReq, User.getFakeUser());
termService.saveOrUpdate(termReq, User.getDefaultUser());
}
public void addTerm_1(DomainResp s2Domain) {
@@ -478,7 +480,7 @@ public class S2VisitsDemo extends S2BaseDemo {
termReq.setDescription("用户为tom和lucy");
termReq.setAlias(Lists.newArrayList("VIP用户"));
termReq.setDomainId(s2Domain.getId());
termService.saveOrUpdate(termReq, User.getFakeUser());
termService.saveOrUpdate(termReq, User.getDefaultUser());
}
public void addAuthGroup_1(ModelResp stayTimeModel) {
@@ -553,7 +555,7 @@ public class S2VisitsDemo extends S2BaseDemo {
tagObjectReq.setDomainId(s2Domain.getId());
tagObjectReq.setName("用户");
tagObjectReq.setBizName("user");
User user = User.getFakeUser();
User user = User.getDefaultUser();
return tagObjectService.create(tagObjectReq, user);
}

View File

@@ -4,8 +4,8 @@ import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentConfig;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
@@ -24,14 +24,14 @@ public class SmallTalkDemo extends S2BaseDemo {
agent.setDescription("直接与大模型对话,验证连通性");
agent.setStatus(1);
agent.setEnableSearch(0);
AgentConfig agentConfig = new AgentConfig();
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
ToolConfig toolConfig = new ToolConfig();
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
agent.setExamples(Lists.newArrayList("如何才能变帅", "如何才能赚更多钱", "如何才能世界和平"));
MultiTurnConfig multiTurnConfig = new MultiTurnConfig();
multiTurnConfig.setEnableMultiTurn(true);
agent.setMultiTurnConfig(multiTurnConfig);
agentService.createAgent(agent, User.getFakeUser());
agentService.createAgent(agent, User.getDefaultUser());
}
@Override

View File

@@ -369,4 +369,21 @@ UPDATE `s2_dimension` SET `type` = 'identify' WHERE `type` in ('primary','foreig
alter table singer drop column imp_date;
--20240913
ALTER TABLE s2_model MODIFY COLUMN drill_down_dimensions TEXT DEFAULT NULL;
ALTER TABLE s2_model MODIFY COLUMN drill_down_dimensions TEXT DEFAULT NULL;
--20241009
CREATE TABLE IF NOT EXISTS `s2_chat_model` (
`id` bigint(20) NOT NULL AUTO_INCREMENT,
`name` varchar(255) NOT NULL COMMENT '名称',
`description` varchar(500) DEFAULT NULL COMMENT '描述',
`config` text NOT NULL COMMENT '配置信息',
`created_at` datetime NOT NULL COMMENT '创建时间',
`created_by` varchar(100) NOT NULL COMMENT '创建人',
`updated_at` datetime NOT NULL COMMENT '更新时间',
`updated_by` varchar(100) NOT NULL COMMENT '更新人',
`admin` varchar(500) DEFAULT NULL,
`viewer` varchar(500) DEFAULT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='对话大模型实例表';
ALTER TABLE s2_agent RENAME COLUMN config TO tool_config;
ALTER TABLE s2_agent RENAME COLUMN model_config TO chat_model_config;

View File

@@ -388,9 +388,9 @@ CREATE TABLE IF NOT EXISTS s2_agent
description varchar(500) null,
status int null,
examples varchar(500) null,
config varchar(2000) null,
tool_config varchar(2000) null,
llm_config varchar(2000) null,
model_config varchar(6000) null,
chat_model_config varchar(6000) null,
prompt_config varchar(5000) null,
multi_turn_config varchar(2000) null,
visual_config varchar(2000) null,

View File

@@ -70,9 +70,9 @@ CREATE TABLE IF NOT EXISTS `s2_agent` (
`examples` TEXT COLLATE utf8_unicode_ci DEFAULT NULL,
`status` tinyint DEFAULT NULL,
`model` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL,
`config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL,
`tool_config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL,
`llm_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,
`model_config` text COLLATE utf8_unicode_ci DEFAULT NULL,
`chat_model_config` text COLLATE utf8_unicode_ci DEFAULT NULL,
`prompt_config` text COLLATE utf8_unicode_ci DEFAULT NULL,
`multi_turn_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,
`visual_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,