mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-15 06:27:21 +00:00
[feature][chat]Refactor chat model config related codes.#1739
This commit is contained in:
@@ -224,7 +224,7 @@ public class CspiderDemo extends S2BaseDemo {
|
||||
queryConfig.setDetailTypeDefaultConfig(detailTypeDefaultConfig);
|
||||
queryConfig.setAggregateTypeDefaultConfig(aggregateTypeDefaultConfig);
|
||||
dataSetReq.setQueryConfig(queryConfig);
|
||||
dataSetService.save(dataSetReq, User.getFakeUser());
|
||||
dataSetService.save(dataSetReq, User.getDefaultUser());
|
||||
}
|
||||
|
||||
public void addModelRela_1(DomainResp s2Domain, ModelResp genreModelResp,
|
||||
@@ -296,6 +296,6 @@ public class CspiderDemo extends S2BaseDemo {
|
||||
|
||||
private void batchPushlishMetric() {
|
||||
List<Long> ids = Lists.newArrayList(1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L);
|
||||
metricService.batchPublish(ids, User.getFakeUser());
|
||||
metricService.batchPublish(ids, User.getDefaultUser());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,9 +4,9 @@ import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.agent.LLMParserTool;
|
||||
import com.tencent.supersonic.chat.server.agent.ToolConfig;
|
||||
import com.tencent.supersonic.common.pojo.JoinCondition;
|
||||
import com.tencent.supersonic.common.pojo.ModelRela;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
@@ -274,7 +274,7 @@ public class DuSQLDemo extends S2BaseDemo {
|
||||
aggregateTypeDefaultConfig.setTimeDefaultConfig(timeDefaultConfig);
|
||||
queryConfig.setAggregateTypeDefaultConfig(aggregateTypeDefaultConfig);
|
||||
dataSetReq.setQueryConfig(queryConfig);
|
||||
dataSetService.save(dataSetReq, User.getFakeUser());
|
||||
dataSetService.save(dataSetReq, User.getDefaultUser());
|
||||
}
|
||||
|
||||
public void addModelRela_1() {
|
||||
@@ -334,16 +334,16 @@ public class DuSQLDemo extends S2BaseDemo {
|
||||
agent.setStatus(1);
|
||||
agent.setEnableSearch(1);
|
||||
agent.setExamples(Lists.newArrayList());
|
||||
AgentConfig agentConfig = new AgentConfig();
|
||||
ToolConfig toolConfig = new ToolConfig();
|
||||
|
||||
LLMParserTool llmParserTool = new LLMParserTool();
|
||||
llmParserTool.setId("1");
|
||||
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
|
||||
llmParserTool.setDataSetIds(Lists.newArrayList(4L));
|
||||
agentConfig.getTools().add(llmParserTool);
|
||||
toolConfig.getTools().add(llmParserTool);
|
||||
|
||||
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
|
||||
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
|
||||
log.info("agent:{}", JsonUtil.toString(agent));
|
||||
agentService.createAgent(agent, User.getFakeUser());
|
||||
agentService.createAgent(agent, User.getDefaultUser());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,10 +4,10 @@ import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.agent.LLMParserTool;
|
||||
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
|
||||
import com.tencent.supersonic.chat.server.agent.ToolConfig;
|
||||
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
|
||||
import com.tencent.supersonic.headless.api.pojo.*;
|
||||
@@ -69,7 +69,7 @@ public class S2ArtistDemo extends S2BaseDemo {
|
||||
tagObjectReq.setDomainId(singerDomain.getId());
|
||||
tagObjectReq.setName("歌手");
|
||||
tagObjectReq.setBizName("singer");
|
||||
User user = User.getFakeUser();
|
||||
User user = User.getDefaultUser();
|
||||
return tagObjectService.create(tagObjectReq, user);
|
||||
}
|
||||
|
||||
@@ -159,7 +159,7 @@ public class S2ArtistDemo extends S2BaseDemo {
|
||||
queryConfig.setDetailTypeDefaultConfig(detailTypeDefaultConfig);
|
||||
queryConfig.setAggregateTypeDefaultConfig(aggregateTypeDefaultConfig);
|
||||
dataSetReq.setQueryConfig(queryConfig);
|
||||
DataSetResp dataSetResp = dataSetService.save(dataSetReq, User.getFakeUser());
|
||||
DataSetResp dataSetResp = dataSetService.save(dataSetReq, User.getDefaultUser());
|
||||
return dataSetResp.getId();
|
||||
}
|
||||
|
||||
@@ -170,21 +170,21 @@ public class S2ArtistDemo extends S2BaseDemo {
|
||||
agent.setStatus(1);
|
||||
agent.setEnableSearch(1);
|
||||
agent.setExamples(Lists.newArrayList("国风流派歌手", "港台歌手", "周杰伦流派"));
|
||||
AgentConfig agentConfig = new AgentConfig();
|
||||
ToolConfig toolConfig = new ToolConfig();
|
||||
RuleParserTool ruleQueryTool = new RuleParserTool();
|
||||
ruleQueryTool.setId("0");
|
||||
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
|
||||
ruleQueryTool.setDataSetIds(Lists.newArrayList(dataSetId));
|
||||
agentConfig.getTools().add(ruleQueryTool);
|
||||
toolConfig.getTools().add(ruleQueryTool);
|
||||
|
||||
if (demoEnableLlm) {
|
||||
LLMParserTool llmParserTool = new LLMParserTool();
|
||||
llmParserTool.setId("1");
|
||||
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
|
||||
llmParserTool.setDataSetIds(Lists.newArrayList(dataSetId));
|
||||
agentConfig.getTools().add(llmParserTool);
|
||||
toolConfig.getTools().add(llmParserTool);
|
||||
}
|
||||
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
|
||||
agentService.createAgent(agent, User.getFakeUser());
|
||||
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
|
||||
agentService.createAgent(agent, User.getDefaultUser());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,10 +3,8 @@ package com.tencent.supersonic.demo;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authorization.service.AuthService;
|
||||
import com.tencent.supersonic.chat.server.service.AgentService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatManageService;
|
||||
import com.tencent.supersonic.chat.server.service.ChatQueryService;
|
||||
import com.tencent.supersonic.chat.server.service.PluginService;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatModel;
|
||||
import com.tencent.supersonic.chat.server.service.*;
|
||||
import com.tencent.supersonic.common.service.SystemConfigService;
|
||||
import com.tencent.supersonic.common.util.AESEncryptionUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
|
||||
@@ -33,6 +31,7 @@ import com.tencent.supersonic.headless.server.service.TagMetaService;
|
||||
import com.tencent.supersonic.headless.server.service.TagObjectService;
|
||||
import com.tencent.supersonic.headless.server.service.TermService;
|
||||
import com.tencent.supersonic.headless.server.service.impl.DictWordService;
|
||||
import dev.langchain4j.provider.ModelProvider;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.apache.commons.lang3.StringUtils;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
@@ -47,8 +46,9 @@ import java.util.stream.Collectors;
|
||||
@Slf4j
|
||||
public abstract class S2BaseDemo implements CommandLineRunner {
|
||||
protected DatabaseResp demoDatabaseResp;
|
||||
protected ChatModel chatModel;
|
||||
|
||||
protected User user = User.getFakeUser();
|
||||
protected User user = User.getDefaultUser();
|
||||
@Autowired
|
||||
protected DatabaseService databaseService;
|
||||
@Autowired
|
||||
@@ -87,6 +87,8 @@ public abstract class S2BaseDemo implements CommandLineRunner {
|
||||
protected CanvasService canvasService;
|
||||
@Autowired
|
||||
protected DictWordService dictWordService;
|
||||
@Autowired
|
||||
protected ChatModelService chatModelService;
|
||||
|
||||
@Value("${s2.demo.names:S2VisitsDemo}")
|
||||
protected List<String> demoList;
|
||||
@@ -96,6 +98,7 @@ public abstract class S2BaseDemo implements CommandLineRunner {
|
||||
|
||||
public void run(String... args) {
|
||||
demoDatabaseResp = addDatabaseIfNotExist();
|
||||
addChatModelIfNotExist();
|
||||
if (demoList != null && demoList.contains(getClass().getSimpleName())) {
|
||||
if (checkNeedToRun()) {
|
||||
doRun();
|
||||
@@ -108,7 +111,7 @@ public abstract class S2BaseDemo implements CommandLineRunner {
|
||||
abstract boolean checkNeedToRun();
|
||||
|
||||
protected DatabaseResp addDatabaseIfNotExist() {
|
||||
List<DatabaseResp> databaseList = databaseService.getDatabaseList(User.getFakeUser());
|
||||
List<DatabaseResp> databaseList = databaseService.getDatabaseList(User.getDefaultUser());
|
||||
if (!CollectionUtils.isEmpty(databaseList)) {
|
||||
return databaseList.get(0);
|
||||
}
|
||||
@@ -130,6 +133,17 @@ public abstract class S2BaseDemo implements CommandLineRunner {
|
||||
return databaseService.createOrUpdateDatabase(databaseReq, user);
|
||||
}
|
||||
|
||||
protected void addChatModelIfNotExist() {
|
||||
if (chatModelService.getChatModels().size() > 0) {
|
||||
return;
|
||||
}
|
||||
chatModel = new ChatModel();
|
||||
chatModel.setName("OpenAI模型DEMO");
|
||||
chatModel.setDescription("由langchain4j社区提供仅用于体验,单次请求最大token数1000");
|
||||
chatModel.setConfig(ModelProvider.DEMO_CHAT_MODEL);
|
||||
chatModel = chatModelService.createChatModel(chatModel, User.getDefaultUser());
|
||||
}
|
||||
|
||||
protected MetricResp getMetric(String bizName, ModelResp model) {
|
||||
return metricService.getMetric(model.getId(), bizName);
|
||||
}
|
||||
@@ -160,7 +174,7 @@ public abstract class S2BaseDemo implements CommandLineRunner {
|
||||
TagReq tagReq = new TagReq();
|
||||
tagReq.setTagDefineType(tagDefineType);
|
||||
tagReq.setItemId(itemId);
|
||||
tagMetaService.create(tagReq, User.getFakeUser());
|
||||
tagMetaService.create(tagReq, User.getDefaultUser());
|
||||
}
|
||||
|
||||
protected DimensionResp getDimension(String bizName, ModelResp model) {
|
||||
|
||||
@@ -2,16 +2,17 @@ package com.tencent.supersonic.demo;
|
||||
|
||||
import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Maps;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
|
||||
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
|
||||
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.agent.LLMParserTool;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
|
||||
import com.tencent.supersonic.chat.server.agent.ToolConfig;
|
||||
import com.tencent.supersonic.chat.server.plugin.ChatPlugin;
|
||||
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
|
||||
@@ -19,12 +20,7 @@ import com.tencent.supersonic.chat.server.plugin.build.webpage.WebPageQuery;
|
||||
import com.tencent.supersonic.chat.server.plugin.build.webservice.WebServiceQuery;
|
||||
import com.tencent.supersonic.common.pojo.JoinCondition;
|
||||
import com.tencent.supersonic.common.pojo.ModelRela;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.SensitiveLevelEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
|
||||
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
|
||||
import com.tencent.supersonic.common.pojo.enums.*;
|
||||
import com.tencent.supersonic.common.util.JsonUtil;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetDetail;
|
||||
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
|
||||
@@ -63,10 +59,7 @@ import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.core.annotation.Order;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.*;
|
||||
|
||||
@Component
|
||||
@Slf4j
|
||||
@@ -146,7 +139,7 @@ public class S2VisitsDemo extends S2BaseDemo {
|
||||
|
||||
private void submitText(int chatId, int agentId, String queryText) {
|
||||
chatQueryService.parseAndExecute(ChatParseReq.builder().chatId(chatId).agentId(agentId)
|
||||
.queryText(queryText).user(User.getFakeUser()).disableLLM(true).build());
|
||||
.queryText(queryText).user(User.getDefaultUser()).disableLLM(true).build());
|
||||
}
|
||||
|
||||
private Integer addAgent(long dataSetId) {
|
||||
@@ -157,23 +150,32 @@ public class S2VisitsDemo extends S2BaseDemo {
|
||||
agent.setEnableSearch(1);
|
||||
agent.setExamples(Lists.newArrayList("近15天超音数访问次数汇总", "按部门统计超音数的访问人数", "对比alice和lucy的停留时长",
|
||||
"过去30天访问次数最高的部门top3", "近1个月总访问次数超过100次的部门有几个", "过去半个月每个核心用户的总停留时长"));
|
||||
AgentConfig agentConfig = new AgentConfig();
|
||||
// configure tools
|
||||
ToolConfig toolConfig = new ToolConfig();
|
||||
RuleParserTool ruleQueryTool = new RuleParserTool();
|
||||
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
|
||||
ruleQueryTool.setId("0");
|
||||
ruleQueryTool.setDataSetIds(Lists.newArrayList(dataSetId));
|
||||
agentConfig.getTools().add(ruleQueryTool);
|
||||
toolConfig.getTools().add(ruleQueryTool);
|
||||
if (demoEnableLlm) {
|
||||
LLMParserTool llmParserTool = new LLMParserTool();
|
||||
llmParserTool.setId("1");
|
||||
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
|
||||
llmParserTool.setDataSetIds(Lists.newArrayList(dataSetId));
|
||||
agentConfig.getTools().add(llmParserTool);
|
||||
toolConfig.getTools().add(llmParserTool);
|
||||
}
|
||||
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
|
||||
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
|
||||
// configure chat models
|
||||
Map<ChatModelType, Integer> chatModelConfig = Maps.newHashMap();
|
||||
chatModelConfig.put(ChatModelType.TEXT_TO_SQL, chatModel.getId());
|
||||
chatModelConfig.put(ChatModelType.MEMORY_REVIEW, chatModel.getId());
|
||||
chatModelConfig.put(ChatModelType.RESPONSE_GENERATE, chatModel.getId());
|
||||
chatModelConfig.put(ChatModelType.MULTI_TURN_REWRITE, chatModel.getId());
|
||||
agent.setChatModelConfig(chatModelConfig);
|
||||
|
||||
MultiTurnConfig multiTurnConfig = new MultiTurnConfig(true);
|
||||
agent.setMultiTurnConfig(multiTurnConfig);
|
||||
Agent agentCreated = agentService.createAgent(agent, User.getFakeUser());
|
||||
Agent agentCreated = agentService.createAgent(agent, User.getDefaultUser());
|
||||
return agentCreated.getId();
|
||||
}
|
||||
|
||||
@@ -460,7 +462,7 @@ public class S2VisitsDemo extends S2BaseDemo {
|
||||
dataSetDetail.setDataSetModelConfigs(dataSetModelConfigs);
|
||||
dataSetReq.setDataSetDetail(dataSetDetail);
|
||||
dataSetReq.setTypeEnum(TypeEnums.DATASET);
|
||||
return dataSetService.save(dataSetReq, User.getFakeUser());
|
||||
return dataSetService.save(dataSetReq, User.getDefaultUser());
|
||||
}
|
||||
|
||||
public void addTerm(DomainResp s2Domain) {
|
||||
@@ -469,7 +471,7 @@ public class S2VisitsDemo extends S2BaseDemo {
|
||||
termReq.setDescription("指近10天");
|
||||
termReq.setAlias(Lists.newArrayList("近一段时间"));
|
||||
termReq.setDomainId(s2Domain.getId());
|
||||
termService.saveOrUpdate(termReq, User.getFakeUser());
|
||||
termService.saveOrUpdate(termReq, User.getDefaultUser());
|
||||
}
|
||||
|
||||
public void addTerm_1(DomainResp s2Domain) {
|
||||
@@ -478,7 +480,7 @@ public class S2VisitsDemo extends S2BaseDemo {
|
||||
termReq.setDescription("用户为tom和lucy");
|
||||
termReq.setAlias(Lists.newArrayList("VIP用户"));
|
||||
termReq.setDomainId(s2Domain.getId());
|
||||
termService.saveOrUpdate(termReq, User.getFakeUser());
|
||||
termService.saveOrUpdate(termReq, User.getDefaultUser());
|
||||
}
|
||||
|
||||
public void addAuthGroup_1(ModelResp stayTimeModel) {
|
||||
@@ -553,7 +555,7 @@ public class S2VisitsDemo extends S2BaseDemo {
|
||||
tagObjectReq.setDomainId(s2Domain.getId());
|
||||
tagObjectReq.setName("用户");
|
||||
tagObjectReq.setBizName("user");
|
||||
User user = User.getFakeUser();
|
||||
User user = User.getDefaultUser();
|
||||
return tagObjectService.create(tagObjectReq, user);
|
||||
}
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ import com.alibaba.fastjson.JSONObject;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.ToolConfig;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.core.annotation.Order;
|
||||
import org.springframework.stereotype.Component;
|
||||
@@ -24,14 +24,14 @@ public class SmallTalkDemo extends S2BaseDemo {
|
||||
agent.setDescription("直接与大模型对话,验证连通性");
|
||||
agent.setStatus(1);
|
||||
agent.setEnableSearch(0);
|
||||
AgentConfig agentConfig = new AgentConfig();
|
||||
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
|
||||
ToolConfig toolConfig = new ToolConfig();
|
||||
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
|
||||
agent.setExamples(Lists.newArrayList("如何才能变帅", "如何才能赚更多钱", "如何才能世界和平"));
|
||||
MultiTurnConfig multiTurnConfig = new MultiTurnConfig();
|
||||
multiTurnConfig.setEnableMultiTurn(true);
|
||||
agent.setMultiTurnConfig(multiTurnConfig);
|
||||
|
||||
agentService.createAgent(agent, User.getFakeUser());
|
||||
agentService.createAgent(agent, User.getDefaultUser());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -369,4 +369,21 @@ UPDATE `s2_dimension` SET `type` = 'identify' WHERE `type` in ('primary','foreig
|
||||
alter table singer drop column imp_date;
|
||||
|
||||
--20240913
|
||||
ALTER TABLE s2_model MODIFY COLUMN drill_down_dimensions TEXT DEFAULT NULL;
|
||||
ALTER TABLE s2_model MODIFY COLUMN drill_down_dimensions TEXT DEFAULT NULL;
|
||||
|
||||
--20241009
|
||||
CREATE TABLE IF NOT EXISTS `s2_chat_model` (
|
||||
`id` bigint(20) NOT NULL AUTO_INCREMENT,
|
||||
`name` varchar(255) NOT NULL COMMENT '名称',
|
||||
`description` varchar(500) DEFAULT NULL COMMENT '描述',
|
||||
`config` text NOT NULL COMMENT '配置信息',
|
||||
`created_at` datetime NOT NULL COMMENT '创建时间',
|
||||
`created_by` varchar(100) NOT NULL COMMENT '创建人',
|
||||
`updated_at` datetime NOT NULL COMMENT '更新时间',
|
||||
`updated_by` varchar(100) NOT NULL COMMENT '更新人',
|
||||
`admin` varchar(500) DEFAULT NULL,
|
||||
`viewer` varchar(500) DEFAULT NULL,
|
||||
PRIMARY KEY (`id`)
|
||||
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='对话大模型实例表';
|
||||
ALTER TABLE s2_agent RENAME COLUMN config TO tool_config;
|
||||
ALTER TABLE s2_agent RENAME COLUMN model_config TO chat_model_config;
|
||||
|
||||
@@ -388,9 +388,9 @@ CREATE TABLE IF NOT EXISTS s2_agent
|
||||
description varchar(500) null,
|
||||
status int null,
|
||||
examples varchar(500) null,
|
||||
config varchar(2000) null,
|
||||
tool_config varchar(2000) null,
|
||||
llm_config varchar(2000) null,
|
||||
model_config varchar(6000) null,
|
||||
chat_model_config varchar(6000) null,
|
||||
prompt_config varchar(5000) null,
|
||||
multi_turn_config varchar(2000) null,
|
||||
visual_config varchar(2000) null,
|
||||
|
||||
@@ -70,9 +70,9 @@ CREATE TABLE IF NOT EXISTS `s2_agent` (
|
||||
`examples` TEXT COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
`status` tinyint DEFAULT NULL,
|
||||
`model` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
`config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
`tool_config` varchar(6000) COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
`llm_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
`model_config` text COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
`chat_model_config` text COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
`prompt_config` text COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
`multi_turn_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
`visual_config` varchar(2000) COLLATE utf8_unicode_ci DEFAULT NULL,
|
||||
|
||||
@@ -7,10 +7,10 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
||||
import com.tencent.supersonic.chat.BaseTest;
|
||||
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
|
||||
import com.tencent.supersonic.chat.server.agent.Agent;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.AgentToolType;
|
||||
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
|
||||
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
|
||||
import com.tencent.supersonic.chat.server.agent.ToolConfig;
|
||||
import com.tencent.supersonic.chat.server.pojo.ChatModel;
|
||||
import com.tencent.supersonic.common.pojo.enums.ChatModelType;
|
||||
import com.tencent.supersonic.util.DataUtils;
|
||||
@@ -49,7 +49,7 @@ public class Text2SQLEval extends BaseTest {
|
||||
QueryResult result = submitNewChat("近30天总访问次数", agentId);
|
||||
durations.add(System.currentTimeMillis() - start);
|
||||
assert result.getQueryColumns().size() == 1;
|
||||
assert result.getQueryColumns().get(0).getName().contains("访问次数");
|
||||
assert result.getTextResult().contains("511");
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -58,8 +58,8 @@ public class Text2SQLEval extends BaseTest {
|
||||
QueryResult result = submitNewChat("近30日每天的访问次数", agentId);
|
||||
durations.add(System.currentTimeMillis() - start);
|
||||
assert result.getQueryColumns().size() == 2;
|
||||
assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("date");
|
||||
assert result.getQueryColumns().get(1).getName().contains("访问次数");
|
||||
assert result.getQueryResults().size() == 30;
|
||||
assert result.getTextResult().contains("date");
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -68,9 +68,11 @@ public class Text2SQLEval extends BaseTest {
|
||||
QueryResult result = submitNewChat("过去30天每个部门的汇总访问次数", agentId);
|
||||
durations.add(System.currentTimeMillis() - start);
|
||||
assert result.getQueryColumns().size() == 2;
|
||||
assert result.getQueryColumns().get(0).getName().equalsIgnoreCase("部门");
|
||||
assert result.getQueryColumns().get(1).getName().contains("访问次数");
|
||||
assert result.getQueryResults().size() == 4;
|
||||
assert result.getTextResult().contains("marketing");
|
||||
assert result.getTextResult().contains("sales");
|
||||
assert result.getTextResult().contains("strategy");
|
||||
assert result.getTextResult().contains("HR");
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -134,16 +136,16 @@ public class Text2SQLEval extends BaseTest {
|
||||
public Agent getLLMAgent(boolean enableMultiturn) {
|
||||
Agent agent = new Agent();
|
||||
agent.setName("Agent for Test");
|
||||
AgentConfig agentConfig = new AgentConfig();
|
||||
agentConfig.getTools().add(getLLMQueryTool());
|
||||
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
|
||||
ToolConfig toolConfig = new ToolConfig();
|
||||
toolConfig.getTools().add(getLLMQueryTool());
|
||||
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
|
||||
ChatModel chatModel = new ChatModel();
|
||||
chatModel.setName("Text2SQL LLM");
|
||||
chatModel.setConfig(LLMConfigUtils.getLLMConfig(LLMConfigUtils.LLMType.OLLAMA_LLAMA3));
|
||||
chatModel = chatModelService.createChatModel(chatModel, User.getFakeUser());
|
||||
chatModel = chatModelService.createChatModel(chatModel, User.getDefaultUser());
|
||||
Map<ChatModelType, Integer> chatModelConfig = Maps.newHashMap();
|
||||
chatModelConfig.put(ChatModelType.TEXT_TO_SQL, chatModel.getId());
|
||||
agent.setModelConfig(chatModelConfig);
|
||||
agent.setChatModelConfig(chatModelConfig);
|
||||
MultiTurnConfig multiTurnConfig = new MultiTurnConfig();
|
||||
multiTurnConfig.setEnableMultiTurn(enableMultiturn);
|
||||
agent.setMultiTurnConfig(multiTurnConfig);
|
||||
|
||||
@@ -34,7 +34,7 @@ public class BaseTest extends BaseApplication {
|
||||
private DomainRepository domainRepository;
|
||||
|
||||
protected SemanticQueryResp queryBySql(String sql) throws Exception {
|
||||
return queryBySql(sql, User.getFakeUser());
|
||||
return queryBySql(sql, User.getDefaultUser());
|
||||
}
|
||||
|
||||
protected SemanticQueryResp queryBySql(String sql, User user) throws Exception {
|
||||
|
||||
@@ -22,7 +22,7 @@ public class MetaDiscoveryTest extends BaseTest {
|
||||
QueryMapReq queryMapReq = new QueryMapReq();
|
||||
queryMapReq.setQueryText("对比alice和lucy的访问次数");
|
||||
queryMapReq.setTopN(10);
|
||||
queryMapReq.setUser(User.getFakeUser());
|
||||
queryMapReq.setUser(User.getDefaultUser());
|
||||
queryMapReq.setDataSetNames(Collections.singletonList("超音数数据集"));
|
||||
MapInfoResp mapMeta = chatLayerService.map(queryMapReq);
|
||||
|
||||
@@ -36,7 +36,7 @@ public class MetaDiscoveryTest extends BaseTest {
|
||||
QueryMapReq queryMapReq = new QueryMapReq();
|
||||
queryMapReq.setQueryText("风格为流行的艺人");
|
||||
queryMapReq.setTopN(10);
|
||||
queryMapReq.setUser(User.getFakeUser());
|
||||
queryMapReq.setUser(User.getDefaultUser());
|
||||
queryMapReq.setDataSetNames(Collections.singletonList("艺人库"));
|
||||
queryMapReq.setQueryDataType(QueryDataType.TAG);
|
||||
MapInfoResp mapMeta = chatLayerService.map(queryMapReq);
|
||||
@@ -48,7 +48,7 @@ public class MetaDiscoveryTest extends BaseTest {
|
||||
QueryMapReq queryMapReq = new QueryMapReq();
|
||||
queryMapReq.setQueryText("超音数访问次数最高的部门");
|
||||
queryMapReq.setTopN(10);
|
||||
queryMapReq.setUser(User.getFakeUser());
|
||||
queryMapReq.setUser(User.getDefaultUser());
|
||||
queryMapReq.setDataSetNames(Collections.singletonList("超音数"));
|
||||
queryMapReq.setQueryDataType(QueryDataType.METRIC);
|
||||
MapInfoResp mapMeta = chatLayerService.map(queryMapReq);
|
||||
|
||||
@@ -23,7 +23,7 @@ public class QueryByMetricTest extends BaseTest {
|
||||
QueryMetricReq queryMetricReq = new QueryMetricReq();
|
||||
queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv"));
|
||||
queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department"));
|
||||
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getFakeUser());
|
||||
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
|
||||
Assert.assertNotNull(queryResp.getResultList());
|
||||
Assert.assertEquals(6, queryResp.getResultList().size());
|
||||
}
|
||||
@@ -33,7 +33,7 @@ public class QueryByMetricTest extends BaseTest {
|
||||
QueryMetricReq queryMetricReq = new QueryMetricReq();
|
||||
queryMetricReq.setMetricNames(Arrays.asList("停留时长", "访问次数"));
|
||||
queryMetricReq.setDimensionNames(Arrays.asList("用户", "部门"));
|
||||
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getFakeUser());
|
||||
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
|
||||
Assert.assertNotNull(queryResp.getResultList());
|
||||
Assert.assertEquals(6, queryResp.getResultList().size());
|
||||
}
|
||||
@@ -44,7 +44,7 @@ public class QueryByMetricTest extends BaseTest {
|
||||
queryMetricReq.setDomainId(1L);
|
||||
queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv"));
|
||||
queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department"));
|
||||
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getFakeUser());
|
||||
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
|
||||
Assert.assertNotNull(queryResp.getResultList());
|
||||
Assert.assertEquals(6, queryResp.getResultList().size());
|
||||
|
||||
@@ -52,7 +52,7 @@ public class QueryByMetricTest extends BaseTest {
|
||||
queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv"));
|
||||
queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department"));
|
||||
assertThrows(IllegalArgumentException.class,
|
||||
() -> queryByMetric(queryMetricReq, User.getFakeUser()));
|
||||
() -> queryByMetric(queryMetricReq, User.getDefaultUser()));
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -61,7 +61,7 @@ public class QueryByMetricTest extends BaseTest {
|
||||
queryMetricReq.setDomainId(1L);
|
||||
queryMetricReq.setMetricIds(Arrays.asList(1L, 3L));
|
||||
queryMetricReq.setDimensionIds(Arrays.asList(1L, 2L));
|
||||
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getFakeUser());
|
||||
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
|
||||
Assert.assertNotNull(queryResp.getResultList());
|
||||
Assert.assertEquals(6, queryResp.getResultList().size());
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ public class QueryByStructTest extends BaseTest {
|
||||
QueryStructReq queryStructReq =
|
||||
buildQueryStructReq(Arrays.asList("user_name", "department"), QueryType.DETAIL);
|
||||
SemanticQueryResp semanticQueryResp =
|
||||
semanticLayerService.queryByReq(queryStructReq, User.getFakeUser());
|
||||
semanticLayerService.queryByReq(queryStructReq, User.getDefaultUser());
|
||||
assertEquals(3, semanticQueryResp.getColumns().size());
|
||||
QueryColumn firstColumn = semanticQueryResp.getColumns().get(0);
|
||||
assertEquals("用户", firstColumn.getName());
|
||||
@@ -64,7 +64,7 @@ public class QueryByStructTest extends BaseTest {
|
||||
public void testSumQuery() throws Exception {
|
||||
QueryStructReq queryStructReq = buildQueryStructReq(null);
|
||||
SemanticQueryResp semanticQueryResp =
|
||||
semanticLayerService.queryByReq(queryStructReq, User.getFakeUser());
|
||||
semanticLayerService.queryByReq(queryStructReq, User.getDefaultUser());
|
||||
assertEquals(1, semanticQueryResp.getColumns().size());
|
||||
QueryColumn queryColumn = semanticQueryResp.getColumns().get(0);
|
||||
assertEquals("访问次数", queryColumn.getName());
|
||||
@@ -75,7 +75,7 @@ public class QueryByStructTest extends BaseTest {
|
||||
public void testGroupByQuery() throws Exception {
|
||||
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department"));
|
||||
SemanticQueryResp result =
|
||||
semanticLayerService.queryByReq(queryStructReq, User.getFakeUser());
|
||||
semanticLayerService.queryByReq(queryStructReq, User.getDefaultUser());
|
||||
assertEquals(2, result.getColumns().size());
|
||||
QueryColumn firstColumn = result.getColumns().get(0);
|
||||
QueryColumn secondColumn = result.getColumns().get(1);
|
||||
@@ -97,7 +97,7 @@ public class QueryByStructTest extends BaseTest {
|
||||
queryStructReq.setDimensionFilters(dimensionFilters);
|
||||
|
||||
SemanticQueryResp result =
|
||||
semanticLayerService.queryByReq(queryStructReq, User.getFakeUser());
|
||||
semanticLayerService.queryByReq(queryStructReq, User.getDefaultUser());
|
||||
assertEquals(2, result.getColumns().size());
|
||||
QueryColumn firstColumn = result.getColumns().get(0);
|
||||
QueryColumn secondColumn = result.getColumns().get(1);
|
||||
|
||||
@@ -15,7 +15,7 @@ public class QueryDimensionTest extends BaseTest {
|
||||
queryDimValueReq.setBizName("department");
|
||||
|
||||
SemanticQueryResp queryResp =
|
||||
semanticLayerService.queryDimensionValue(queryDimValueReq, User.getFakeUser());
|
||||
semanticLayerService.queryDimensionValue(queryDimValueReq, User.getDefaultUser());
|
||||
Assert.assertNotNull(queryResp.getResultList());
|
||||
Assert.assertEquals(4, queryResp.getResultList().size());
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ public class QueryRuleTest extends BaseTest {
|
||||
@Autowired
|
||||
private QueryRuleService queryRuleService;
|
||||
|
||||
private User user = User.getFakeUser();
|
||||
private User user = User.getDefaultUser();
|
||||
|
||||
public QueryRuleReq addSystemRule() {
|
||||
QueryRuleReq queryRuleReq = new QueryRuleReq();
|
||||
|
||||
@@ -18,7 +18,7 @@ public class TagObjectTest extends BaseTest {
|
||||
|
||||
@Test
|
||||
void testCreateTagObject() throws Exception {
|
||||
User user = User.getFakeUser();
|
||||
User user = User.getDefaultUser();
|
||||
TagObjectReq tagObjectReq = newTagObjectReq();
|
||||
TagObjectResp tagObjectResp = tagObjectService.create(tagObjectReq, user);
|
||||
tagObjectService.delete(tagObjectResp.getId(), user, false);
|
||||
@@ -27,24 +27,25 @@ public class TagObjectTest extends BaseTest {
|
||||
@Test
|
||||
void testUpdateTagObject() throws Exception {
|
||||
TagObjectReq tagObjectReq = newTagObjectReq();
|
||||
TagObjectResp tagObjectResp = tagObjectService.create(tagObjectReq, User.getFakeUser());
|
||||
TagObjectResp tagObjectResp = tagObjectService.create(tagObjectReq, User.getDefaultUser());
|
||||
TagObjectReq tagObjectReqUpdate = new TagObjectReq();
|
||||
BeanUtils.copyProperties(tagObjectResp, tagObjectReqUpdate);
|
||||
tagObjectReqUpdate.setName("艺人1");
|
||||
tagObjectService.update(tagObjectReqUpdate, User.getFakeUser());
|
||||
tagObjectService.update(tagObjectReqUpdate, User.getDefaultUser());
|
||||
TagObjectResp tagObject =
|
||||
tagObjectService.getTagObject(tagObjectReqUpdate.getId(), User.getFakeUser());
|
||||
tagObjectService.delete(tagObject.getId(), User.getFakeUser(), false);
|
||||
tagObjectService.getTagObject(tagObjectReqUpdate.getId(), User.getDefaultUser());
|
||||
tagObjectService.delete(tagObject.getId(), User.getDefaultUser(), false);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testQueryTagObject() throws Exception {
|
||||
TagObjectReq tagObjectReq = newTagObjectReq();
|
||||
TagObjectResp tagObjectResp = tagObjectService.create(tagObjectReq, User.getFakeUser());
|
||||
TagObjectResp tagObjectResp = tagObjectService.create(tagObjectReq, User.getDefaultUser());
|
||||
TagObjectFilter filter = new TagObjectFilter();
|
||||
List<TagObjectResp> tagObjects = tagObjectService.getTagObjects(filter, User.getFakeUser());
|
||||
List<TagObjectResp> tagObjects =
|
||||
tagObjectService.getTagObjects(filter, User.getDefaultUser());
|
||||
tagObjects.size();
|
||||
tagObjectService.delete(tagObjectResp.getId(), User.getFakeUser(), false);
|
||||
tagObjectService.delete(tagObjectResp.getId(), User.getDefaultUser(), false);
|
||||
}
|
||||
|
||||
private TagObjectReq newTagObjectReq() {
|
||||
|
||||
@@ -21,7 +21,7 @@ public class TagTest extends BaseTest {
|
||||
ItemValueReq itemValueReq = new ItemValueReq();
|
||||
itemValueReq.setId(1L);
|
||||
ItemValueResp itemValueResp =
|
||||
tagQueryService.queryTagValue(itemValueReq, User.getFakeUser());
|
||||
tagQueryService.queryTagValue(itemValueReq, User.getDefaultUser());
|
||||
Assertions.assertNotNull(itemValueResp);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ public class TranslateTest extends BaseTest {
|
||||
String sql = "SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ";
|
||||
SemanticTranslateResp explain = semanticLayerService.translate(
|
||||
QueryReqBuilder.buildS2SQLReq(sql, DataUtils.getMetricAgentView()),
|
||||
User.getFakeUser());
|
||||
User.getDefaultUser());
|
||||
assertNotNull(explain);
|
||||
assertNotNull(explain.getQuerySQL());
|
||||
assertTrue(explain.getQuerySQL().contains("department"));
|
||||
@@ -30,7 +30,7 @@ public class TranslateTest extends BaseTest {
|
||||
public void testStructExplain() throws Exception {
|
||||
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department"));
|
||||
SemanticTranslateResp explain =
|
||||
semanticLayerService.translate(queryStructReq, User.getFakeUser());
|
||||
semanticLayerService.translate(queryStructReq, User.getDefaultUser());
|
||||
assertNotNull(explain);
|
||||
assertNotNull(explain.getQuerySQL());
|
||||
assertTrue(explain.getQuerySQL().contains("department"));
|
||||
|
||||
@@ -19,7 +19,7 @@ public class DataUtils {
|
||||
public static final Integer tagAgentId = 2;
|
||||
public static final Integer ONE_TURNS_CHAT_ID = 10;
|
||||
public static final Integer MULTI_TURNS_CHAT_ID = 11;
|
||||
private static final User user_test = User.getFakeUser();
|
||||
private static final User user_test = User.getDefaultUser();
|
||||
|
||||
public static User getUser() {
|
||||
return user_test;
|
||||
|
||||
@@ -388,9 +388,9 @@ CREATE TABLE IF NOT EXISTS s2_agent
|
||||
description varchar(500) null,
|
||||
status int null,
|
||||
examples varchar(500) null,
|
||||
config varchar(2000) null,
|
||||
tool_config varchar(2000) null,
|
||||
llm_config varchar(2000) null,
|
||||
model_config varchar(6000) null,
|
||||
chat_model_config varchar(6000) null,
|
||||
prompt_config varchar(5000) null,
|
||||
multi_turn_config varchar(2000) null,
|
||||
visual_config varchar(2000) null,
|
||||
|
||||
Reference in New Issue
Block a user