(improvement)(launcher)Refactor Demo impl and config

This commit is contained in:
jerryjzhang
2024-05-25 00:26:55 +08:00
parent be7629eb65
commit 5d16aa0ab4
10 changed files with 584 additions and 678 deletions

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.chat.server.parser;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository;
import com.tencent.supersonic.chat.server.pojo.ChatParseContext;
import com.tencent.supersonic.chat.server.util.QueryReqConverter;
@@ -34,14 +35,14 @@ import java.util.Collections;
@Slf4j
public class MultiTurnParser implements ChatParser {
private static final Logger keyPipelineLog = LoggerFactory.getLogger(MultiTurnParser.class);
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
private static final PromptTemplate promptTemplate = PromptTemplate.from(
"You are a data product manager experienced in data requirements."
+ "Your will be provided with current and history questions asked by a user,"
+ "along with their mapped schema elements(metric, dimension and value), "
+ "please try understanding the semantics and rewrite a question"
+ "(keep relevant metrics, dimensions, values and date ranges)."
+ "(keep relevant entities, metrics, dimensions, values and date ranges)."
+ "Current Question: {{curtQuestion}} "
+ "Current Mapped Schema: {{curtSchema}} "
+ "History Question: {{histQuestion}} "
@@ -51,8 +52,11 @@ public class MultiTurnParser implements ChatParser {
@Override
public void parse(ChatParseContext chatParseContext, ParseResp parseResp) {
Environment environment = ContextUtils.getBean(Environment.class);
Boolean multiTurn = environment.getProperty("multi.turn", Boolean.class);
if (!Boolean.TRUE.equals(multiTurn)) {
MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig();
Boolean globalMultiTurnConfig = environment.getProperty("s2.multi-turn.enable", Boolean.class);
Boolean multiTurnConfig = agentMultiTurnConfig != null ? agentMultiTurnConfig.isEnableMultiTurn() : globalMultiTurnConfig;
if (!Boolean.TRUE.equals(multiTurnConfig)) {
return;
}

View File

@@ -1,231 +0,0 @@
package com.tencent.supersonic;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentConfig;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.LLMParserTool;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import com.tencent.supersonic.chat.server.agent.RuleParserTool;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.common.pojo.SysParameter;
import com.tencent.supersonic.common.service.SysParameterService;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.CommandLineRunner;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
@Component
@Slf4j
@Order(3)
public class ChatDemoLoader implements CommandLineRunner {
private User user = User.getFakeUser();
@Autowired
private ChatService chatService;
@Autowired
private ChatManageService chatManageService;
@Autowired
private AgentService agentService;
@Autowired
private SysParameterService sysParameterService;
@Value("${demo.enabled:false}")
private boolean demoEnabled;
@Value("${demo.nl2SqlLlm.enabled:true}")
private boolean demoEnabledNl2SqlLlm;
@Override
public void run(String... args) throws Exception {
if (!checkEnable()) {
log.info("skip load chat demo");
return;
}
doRun();
}
public void doRun() {
try {
addSysParameter();
Integer agentId = addAgent1();
addAgent2();
addAgent3();
//addAgent4();
addSampleChats(agentId);
addSampleChats2(agentId);
updateQueryScore(1);
updateQueryScore(4);
} catch (Exception e) {
log.error("Failed to add sample chats", e);
}
}
private void parseAndExecute(int chatId, int agentId, String queryText) throws Exception {
ChatParseReq chatParseReq = new ChatParseReq();
chatParseReq.setQueryText(queryText);
chatParseReq.setChatId(chatId);
chatParseReq.setAgentId(agentId);
chatParseReq.setUser(User.getFakeUser());
ParseResp parseResp = chatService.performParsing(chatParseReq);
if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) {
log.info("parseResp.getSelectedParses() is empty");
return;
}
ChatExecuteReq executeReq = new ChatExecuteReq();
executeReq.setQueryId(parseResp.getQueryId());
executeReq.setParseId(parseResp.getSelectedParses().get(0).getId());
executeReq.setQueryText(queryText);
executeReq.setChatId(parseResp.getChatId());
executeReq.setUser(User.getFakeUser());
executeReq.setSaveAnswer(true);
chatService.performExecution(executeReq);
}
public void addSampleChats(Integer agentId) throws Exception {
Long chatId = chatManageService.addChat(user, "样例对话1", agentId);
parseAndExecute(chatId.intValue(), agentId, "超音数 访问次数");
parseAndExecute(chatId.intValue(), agentId, "按部门统计");
parseAndExecute(chatId.intValue(), agentId, "查询近30天");
}
public void addSampleChats2(Integer agentId) throws Exception {
Long chatId = chatManageService.addChat(user, "样例对话2", agentId);
parseAndExecute(chatId.intValue(), agentId, "alice 停留时长");
parseAndExecute(chatId.intValue(), agentId, "对比alice和lucy的访问次数");
parseAndExecute(chatId.intValue(), agentId, "访问次数最高的部门");
}
public void addSysParameter() {
SysParameter sysParameter = new SysParameter();
sysParameter.setId(1);
sysParameter.init();
sysParameterService.save(sysParameter);
}
private Integer addAgent1() {
Agent agent = new Agent();
agent.setId(1);
agent.setName("算指标");
agent.setDescription("帮助您用自然语言查询指标,支持时间限定、条件筛选、下钻维度以及聚合统计");
agent.setStatus(1);
agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList("超音数访问次数", "近15天超音数访问次数汇总", "按部门统计超音数的访问人数",
"对比alice和lucy的停留时长", "超音数访问次数最高的部门"));
AgentConfig agentConfig = new AgentConfig();
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setId("0");
ruleQueryTool.setDataSetIds(Lists.newArrayList(-1L));
agentConfig.getTools().add(ruleQueryTool);
if (demoEnabledNl2SqlLlm) {
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(-1L));
agentConfig.getTools().add(llmParserTool);
}
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
MultiTurnConfig multiTurnConfig = new MultiTurnConfig(false);
agent.setMultiTurnConfig(multiTurnConfig);
agentService.createAgent(agent, User.getFakeUser());
return agent.getId();
}
private void addAgent2() {
Agent agent = new Agent();
agent.setId(2);
agent.setName("标签圈选");
agent.setDescription("帮助您用自然语言进行圈选,支持多条件组合筛选");
agent.setStatus(1);
agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList("国风风格艺人", "港台地区的艺人", "风格为流行的艺人"));
AgentConfig agentConfig = new AgentConfig();
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setId("0");
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setDataSetIds(Lists.newArrayList(-1L));
agentConfig.getTools().add(ruleQueryTool);
if (demoEnabledNl2SqlLlm) {
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(-1L));
agentConfig.getTools().add(llmParserTool);
}
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
agentService.createAgent(agent, User.getFakeUser());
}
private void addAgent3() {
Agent agent = new Agent();
agent.setId(3);
agent.setName("cspider");
agent.setDescription("cspider数据集的case展示");
agent.setStatus(1);
agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList("可用“mp4”格式且分辨率低于1000的歌曲的ID是什么",
"“孟加拉语”歌曲的平均评分和分辨率是多少?",
"找出所有至少有一首“英文”歌曲的艺术家的名字和作品数量。"));
AgentConfig agentConfig = new AgentConfig();
if (demoEnabledNl2SqlLlm) {
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(3L));
agentConfig.getTools().add(llmParserTool);
}
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
agentService.createAgent(agent, User.getFakeUser());
}
private void addAgent4() {
Agent agent = new Agent();
agent.setId(4);
agent.setName("DuSQL 互联网企业");
agent.setDescription("DuSQL");
agent.setStatus(1);
agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList());
AgentConfig agentConfig = new AgentConfig();
if (demoEnabledNl2SqlLlm) {
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(4L));
agentConfig.getTools().add(llmParserTool);
}
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
log.info("agent:{}", JsonUtil.toString(agent));
agentService.createAgent(agent, User.getFakeUser());
}
private void updateQueryScore(Integer queryId) {
chatManageService.updateFeedback(queryId, 5, "");
}
private boolean checkEnable() {
if (!demoEnabled) {
return false;
}
return HeadlessDemoLoader.isLoad();
}
}

View File

@@ -1,53 +0,0 @@
package com.tencent.supersonic;
import com.tencent.supersonic.headless.server.service.DomainService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.CommandLineRunner;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
@Component
@Slf4j
@Order(1)
public class HeadlessDemoLoader implements CommandLineRunner {
private static boolean isLoad = false;
@Autowired
private DomainService domainService;
@Autowired
private ModelDemoDataLoader modelDataDemoLoader;
@Autowired
private BenchMarkDemoDataLoader benchMarkDemoLoader;
@Value("${demo.enabled:false}")
private boolean demoEnabled;
@Override
public void run(String... args) {
if (!checkLoadDemo()) {
log.info("skip load demo");
return;
}
modelDataDemoLoader.doRun();
benchMarkDemoLoader.doRun();
isLoad = true;
}
private boolean checkLoadDemo() {
if (!demoEnabled) {
return false;
}
return CollectionUtils.isEmpty(domainService.getDomainList());
}
public static boolean isLoad() {
return isLoad;
}
}

View File

@@ -1,4 +1,4 @@
package com.tencent.supersonic;
package com.tencent.supersonic.demo;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
@@ -28,13 +28,7 @@ import com.tencent.supersonic.headless.api.pojo.request.DataSetReq;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.server.service.DomainService;
import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.ModelRelaService;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.service.DataSetService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
@@ -44,31 +38,15 @@ import java.util.List;
@Component
@Slf4j
public class BenchMarkDemoDataLoader {
private User user = User.getFakeUser();
@Autowired
private DomainService domainService;
@Autowired
private ModelService modelService;
@Autowired
private ModelRelaService modelRelaService;
@Autowired
private DataSetService viewService;
@Autowired
private MetricService metricService;
@Autowired
private ModelDemoDataLoader modelDemoDataLoader;
public class CspiderDemo extends S2BaseDemo {
public void doRun() {
try {
DatabaseResp databaseResp = modelDemoDataLoader.tmpDatabaseResp;
DomainResp s2Domain = addDomain();
ModelResp genreModelResp = addModel_1(s2Domain, databaseResp);
ModelResp artistModelResp = addModel_2(s2Domain, databaseResp);
ModelResp filesModelResp = addModel_3(s2Domain, databaseResp);
ModelResp songModelResp = addModel_4(s2Domain, databaseResp);
ModelResp genreModelResp = addModel_1(s2Domain, demoDatabaseResp);
ModelResp artistModelResp = addModel_2(s2Domain, demoDatabaseResp);
ModelResp filesModelResp = addModel_3(s2Domain, demoDatabaseResp);
ModelResp songModelResp = addModel_4(s2Domain, demoDatabaseResp);
addDataSet_1(s2Domain);
addModelRela_1(s2Domain, genreModelResp, artistModelResp);
addModelRela_2(s2Domain, filesModelResp, artistModelResp);
@@ -79,7 +57,6 @@ public class BenchMarkDemoDataLoader {
} catch (Exception e) {
log.error("Failed to add bench mark demo data", e);
}
}
public DomainResp addDomain() {
@@ -222,17 +199,17 @@ public class BenchMarkDemoDataLoader {
}
public void addDataSet_1(DomainResp s2Domain) {
DataSetReq viewReq = new DataSetReq();
viewReq.setName("cspider");
viewReq.setBizName("singer");
viewReq.setDomainId(s2Domain.getId());
viewReq.setDescription("包含cspider数据集相关标签和指标信息");
viewReq.setAdmins(Lists.newArrayList("admin"));
List<DataSetModelConfig> viewModelConfigs = modelDemoDataLoader.getDataSetModelConfigs(s2Domain.getId());
DataSetDetail viewDetail = new DataSetDetail();
viewDetail.setDataSetModelConfigs(viewModelConfigs);
viewReq.setDataSetDetail(viewDetail);
viewReq.setTypeEnum(TypeEnums.DATASET);
DataSetReq dataSetReq = new DataSetReq();
dataSetReq.setName("cspider");
dataSetReq.setBizName("singer");
dataSetReq.setDomainId(s2Domain.getId());
dataSetReq.setDescription("包含cspider数据集相关标签和指标信息");
dataSetReq.setAdmins(Lists.newArrayList("admin"));
List<DataSetModelConfig> viewModelConfigs = getDataSetModelConfigs(s2Domain.getId());
DataSetDetail dsDetail = new DataSetDetail();
dsDetail.setDataSetModelConfigs(viewModelConfigs);
dataSetReq.setDataSetDetail(dsDetail);
dataSetReq.setTypeEnum(TypeEnums.DATASET);
QueryConfig queryConfig = new QueryConfig();
TagTypeDefaultConfig tagTypeDefaultConfig = new TagTypeDefaultConfig();
TimeDefaultConfig tagTimeDefaultConfig = new TimeDefaultConfig();
@@ -250,8 +227,8 @@ public class BenchMarkDemoDataLoader {
metricTypeDefaultConfig.setTimeDefaultConfig(timeDefaultConfig);
queryConfig.setTagTypeDefaultConfig(tagTypeDefaultConfig);
queryConfig.setMetricTypeDefaultConfig(metricTypeDefaultConfig);
viewReq.setQueryConfig(queryConfig);
viewService.save(viewReq, User.getFakeUser());
dataSetReq.setQueryConfig(queryConfig);
dataSetService.save(dataSetReq, User.getFakeUser());
}
public void addModelRela_1(DomainResp s2Domain, ModelResp genreModelResp, ModelResp artistModelResp) {

View File

@@ -1,13 +1,20 @@
package com.tencent.supersonic;
package com.tencent.supersonic.demo;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentConfig;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.LLMParserTool;
import com.tencent.supersonic.common.pojo.JoinCondition;
import com.tencent.supersonic.common.pojo.ModelRela;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.demo.S2BaseDemo;
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
import com.tencent.supersonic.headless.api.pojo.Dim;
@@ -25,14 +32,8 @@ import com.tencent.supersonic.headless.api.pojo.request.MetricReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
import com.tencent.supersonic.headless.api.pojo.request.DataSetReq;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.server.service.DomainService;
import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.ModelRelaService;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.service.DataSetService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
@@ -42,21 +43,7 @@ import java.util.List;
@Component
@Slf4j
public class DuSQLDemoDataLoader {
private User user = User.getFakeUser();
@Autowired
private DomainService domainService;
@Autowired
private ModelService modelService;
@Autowired
private ModelRelaService modelRelaService;
@Autowired
private MetricService metricService;
@Autowired
private DataSetService viewService;
public class DuSQLDemo extends S2BaseDemo {
public void doRun() {
try {
@@ -70,6 +57,7 @@ public class DuSQLDemoDataLoader {
addModelRela_2();
addModelRela_3();
addModelRela_4();
addAgent();
} catch (Exception e) {
log.error("Failed to add bench mark demo data", e);
}
@@ -255,22 +243,22 @@ public class DuSQLDemoDataLoader {
}
public void addDataSet_1() {
DataSetReq viewReq = new DataSetReq();
viewReq.setName("DuSQL 互联网企业");
viewReq.setBizName("internet");
viewReq.setDomainId(4L);
viewReq.setDescription("DuSQL互联网企业数据源相关的指标和维度等");
viewReq.setAdmins(Lists.newArrayList("admin"));
DataSetReq dataSetReq = new DataSetReq();
dataSetReq.setName("DuSQL 互联网企业");
dataSetReq.setBizName("internet");
dataSetReq.setDomainId(4L);
dataSetReq.setDescription("DuSQL互联网企业数据源相关的指标和维度等");
dataSetReq.setAdmins(Lists.newArrayList("admin"));
List<DataSetModelConfig> viewModelConfigs = Lists.newArrayList(
new DataSetModelConfig(9L, Lists.newArrayList(16L, 17L, 18L, 19L, 20L), Lists.newArrayList(10L, 11L)),
new DataSetModelConfig(10L, Lists.newArrayList(21L, 22L, 23L), Lists.newArrayList(12L)),
new DataSetModelConfig(11L, Lists.newArrayList(), Lists.newArrayList(13L, 14L, 15L)),
new DataSetModelConfig(12L, Lists.newArrayList(24L), Lists.newArrayList(16L, 17L, 18L, 19L)));
DataSetDetail viewDetail = new DataSetDetail();
viewDetail.setDataSetModelConfigs(viewModelConfigs);
viewReq.setDataSetDetail(viewDetail);
viewReq.setTypeEnum(TypeEnums.DATASET);
DataSetDetail dsDetail = new DataSetDetail();
dsDetail.setDataSetModelConfigs(viewModelConfigs);
dataSetReq.setDataSetDetail(dsDetail);
dataSetReq.setTypeEnum(TypeEnums.DATASET);
QueryConfig queryConfig = new QueryConfig();
MetricTypeDefaultConfig metricTypeDefaultConfig = new MetricTypeDefaultConfig();
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
@@ -278,8 +266,8 @@ public class DuSQLDemoDataLoader {
timeDefaultConfig.setUnit(1);
metricTypeDefaultConfig.setTimeDefaultConfig(timeDefaultConfig);
queryConfig.setMetricTypeDefaultConfig(metricTypeDefaultConfig);
viewReq.setQueryConfig(queryConfig);
viewService.save(viewReq, User.getFakeUser());
dataSetReq.setQueryConfig(queryConfig);
dataSetService.save(dataSetReq, User.getFakeUser());
}
public void addModelRela_1() {
@@ -330,4 +318,24 @@ public class DuSQLDemoDataLoader {
modelRelaService.save(modelRelaReq, user);
}
private void addAgent() {
Agent agent = new Agent();
agent.setName("DuSQL 互联网企业");
agent.setDescription("DuSQL");
agent.setStatus(1);
agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList());
AgentConfig agentConfig = new AgentConfig();
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(4L));
agentConfig.getTools().add(llmParserTool);
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
log.info("agent:{}", JsonUtil.toString(agent));
agentService.createAgent(agent, User.getFakeUser());
}
}

View File

@@ -0,0 +1,197 @@
package com.tencent.supersonic.demo;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.server.agent.*;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.headless.api.pojo.DataSetDetail;
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
import com.tencent.supersonic.headless.api.pojo.DefaultDisplayInfo;
import com.tencent.supersonic.headless.api.pojo.Dim;
import com.tencent.supersonic.headless.api.pojo.DimensionTimeTypeParams;
import com.tencent.supersonic.headless.api.pojo.Identify;
import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.pojo.MetricTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
import com.tencent.supersonic.headless.api.pojo.enums.TagDefineType;
import com.tencent.supersonic.headless.api.pojo.request.DataSetReq;
import com.tencent.supersonic.headless.api.pojo.request.DomainReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
import com.tencent.supersonic.headless.api.pojo.request.TagObjectReq;
import com.tencent.supersonic.headless.api.pojo.response.*;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@Component
@Slf4j
@Order(2)
public class S2ArtistDemo extends S2BaseDemo {
public void doRun() {
try {
DomainResp singerDomain = addDomain();
TagObjectResp singerTagObject = addTagObjectSinger(singerDomain);
ModelResp singerModel = addModel(singerDomain, demoDatabaseResp, singerTagObject);
addTags(singerModel);
long dataSetId = addDataSet(singerDomain, singerModel);
addAgent(dataSetId);
} catch (Exception e) {
log.error("Failed to add model demo data", e);
}
}
private TagObjectResp addTagObjectSinger(DomainResp singerDomain) throws Exception {
TagObjectReq tagObjectReq = new TagObjectReq();
tagObjectReq.setDomainId(singerDomain.getId());
tagObjectReq.setName("艺人");
tagObjectReq.setBizName("singer");
User user = User.getFakeUser();
return tagObjectService.create(tagObjectReq, user);
}
public DomainResp addDomain() {
DomainReq domainReq = new DomainReq();
domainReq.setName("艺人库");
domainReq.setBizName("supersonic");
domainReq.setParentId(0L);
domainReq.setStatus(StatusEnum.ONLINE.getCode());
domainReq.setViewers(Arrays.asList("admin", "tom", "jack"));
domainReq.setViewOrgs(Collections.singletonList("1"));
domainReq.setAdmins(Arrays.asList("admin", "alice"));
domainReq.setAdminOrgs(Collections.emptyList());
return domainService.createDomain(domainReq, user);
}
public ModelResp addModel(DomainResp singerDomain,
DatabaseResp s2Database, TagObjectResp singerTagObject) throws Exception {
ModelReq modelReq = new ModelReq();
modelReq.setName("艺人库");
modelReq.setBizName("singer");
modelReq.setDescription("艺人库");
modelReq.setDatabaseId(s2Database.getId());
modelReq.setDomainId(singerDomain.getId());
modelReq.setTagObjectId(singerTagObject.getId());
modelReq.setViewers(Arrays.asList("admin", "tom", "jack"));
modelReq.setViewOrgs(Collections.singletonList("1"));
modelReq.setAdmins(Collections.singletonList("admin"));
modelReq.setAdminOrgs(Collections.emptyList());
ModelDetail modelDetail = new ModelDetail();
List<Identify> identifiers = new ArrayList<>();
Identify identify = new Identify("歌手名", IdentifyType.primary.name(), "singer_name", 1);
identify.setEntityNames(Lists.newArrayList("歌手", "艺人"));
identifiers.add(identify);
modelDetail.setIdentifiers(identifiers);
List<Dim> dimensions = new ArrayList<>();
Dim dimension1 = new Dim("", "imp_date", DimensionType.time.name(), 0);
dimension1.setTypeParams(new DimensionTimeTypeParams());
dimensions.add(dimension1);
dimensions.add(new Dim("活跃区域", "act_area",
DimensionType.categorical.name(), 1, 1));
dimensions.add(new Dim("代表作", "song_name",
DimensionType.categorical.name(), 1));
dimensions.add(new Dim("风格", "genre",
DimensionType.categorical.name(), 1, 1));
modelDetail.setDimensions(dimensions);
Measure measure1 = new Measure("播放量", "js_play_cnt", "sum", 1);
Measure measure2 = new Measure("下载量", "down_cnt", "sum", 1);
Measure measure3 = new Measure("收藏量", "favor_cnt", "sum", 1);
modelDetail.setMeasures(Lists.newArrayList(measure1, measure2, measure3));
modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("select imp_date, singer_name, act_area, song_name, genre, "
+ "js_play_cnt, down_cnt, favor_cnt from singer");
modelReq.setModelDetail(modelDetail);
return modelService.createModel(modelReq, user);
}
private void addTags(ModelResp model) {
addTag(dimensionService.getDimension("act_area", model.getId()).getId(),
TagDefineType.DIMENSION);
addTag(dimensionService.getDimension("song_name", model.getId()).getId(),
TagDefineType.DIMENSION);
addTag(dimensionService.getDimension("genre", model.getId()).getId(),
TagDefineType.DIMENSION);
addTag(dimensionService.getDimension("singer_name", model.getId()).getId(),
TagDefineType.DIMENSION);
addTag(metricService.getMetric(model.getId(), "js_play_cnt").getId(),
TagDefineType.METRIC);
}
public long addDataSet(DomainResp singerDomain, ModelResp singerModel) {
DataSetReq dataSetReq = new DataSetReq();
dataSetReq.setName("艺人库");
dataSetReq.setBizName("singer");
dataSetReq.setDomainId(singerDomain.getId());
dataSetReq.setDescription("包含艺人相关标签和指标信息");
dataSetReq.setAdmins(Lists.newArrayList("admin", "jack"));
List<DataSetModelConfig> dataSetModelConfigs = getDataSetModelConfigs(singerDomain.getId());
DataSetDetail dataSetDetail = new DataSetDetail();
dataSetDetail.setDataSetModelConfigs(dataSetModelConfigs);
dataSetReq.setDataSetDetail(dataSetDetail);
dataSetReq.setTypeEnum(TypeEnums.DATASET);
QueryConfig queryConfig = new QueryConfig();
TagTypeDefaultConfig tagTypeDefaultConfig = new TagTypeDefaultConfig();
TimeDefaultConfig tagTimeDefaultConfig = new TimeDefaultConfig();
tagTimeDefaultConfig.setTimeMode(TimeMode.LAST);
tagTimeDefaultConfig.setUnit(7);
tagTypeDefaultConfig.setTimeDefaultConfig(tagTimeDefaultConfig);
DefaultDisplayInfo defaultDisplayInfo = new DefaultDisplayInfo();
defaultDisplayInfo.setDimensionIds(dataSetModelConfigs.get(0).getDimensions());
MetricResp jsPlayCntMetric = getMetric("js_play_cnt", singerModel);
defaultDisplayInfo.setMetricIds(Lists.newArrayList(jsPlayCntMetric.getId()));
tagTypeDefaultConfig.setDefaultDisplayInfo(defaultDisplayInfo);
MetricTypeDefaultConfig metricTypeDefaultConfig = new MetricTypeDefaultConfig();
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
timeDefaultConfig.setTimeMode(TimeMode.RECENT);
timeDefaultConfig.setUnit(7);
metricTypeDefaultConfig.setTimeDefaultConfig(timeDefaultConfig);
queryConfig.setTagTypeDefaultConfig(tagTypeDefaultConfig);
queryConfig.setMetricTypeDefaultConfig(metricTypeDefaultConfig);
dataSetReq.setQueryConfig(queryConfig);
DataSetResp dataSetResp = dataSetService.save(dataSetReq, User.getFakeUser());
return dataSetResp.getId();
}
private void addAgent(long dataSetId) {
Agent agent = new Agent();
agent.setName("做圈选");
agent.setDescription("帮助您用自然语言进行圈选,支持多条件组合筛选");
agent.setStatus(1);
agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList("国风风格艺人", "港台地区的艺人", "风格为流行的艺人"));
AgentConfig agentConfig = new AgentConfig();
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setId("0");
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setDataSetIds(Lists.newArrayList(dataSetId));
agentConfig.getTools().add(ruleQueryTool);
if (demoEnableLlm) {
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(dataSetId));
agentConfig.getTools().add(llmParserTool);
}
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
int id = agentService.createAgent(agent, User.getFakeUser());
agent.setId(id);
}
}

View File

@@ -0,0 +1,176 @@
package com.tencent.supersonic.demo;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authorization.service.AuthService;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatManageService;
import com.tencent.supersonic.chat.server.service.ChatService;
import com.tencent.supersonic.chat.server.service.PluginService;
import com.tencent.supersonic.common.service.SysParameterService;
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
import com.tencent.supersonic.headless.api.pojo.DrillDownDimension;
import com.tencent.supersonic.headless.api.pojo.RelateDimension;
import com.tencent.supersonic.headless.api.pojo.enums.DataType;
import com.tencent.supersonic.headless.api.pojo.enums.TagDefineType;
import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq;
import com.tencent.supersonic.headless.api.pojo.request.TagReq;
import com.tencent.supersonic.headless.api.pojo.response.*;
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.service.*;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.CommandLineRunner;
import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties;
import org.springframework.util.CollectionUtils;
import java.util.List;
import java.util.stream.Collectors;
@Slf4j
public abstract class S2BaseDemo implements CommandLineRunner {
protected DatabaseResp demoDatabaseResp;
protected User user = User.getFakeUser();
@Autowired
protected DatabaseService databaseService;
@Autowired
protected DomainService domainService;
@Autowired
protected ModelService modelService;
@Autowired
protected ModelRelaService modelRelaService;
@Autowired
protected DimensionService dimensionService;
@Autowired
protected MetricService metricService;
@Autowired
protected TagMetaService tagMetaService;
@Autowired
protected AuthService authService;
@Autowired
protected DataSetService dataSetService;
@Autowired
protected TermService termService;
@Autowired
protected PluginService pluginService;
@Autowired
protected DataSourceProperties dataSourceProperties;
@Autowired
protected TagObjectService tagObjectService;
@Autowired
protected ChatService chatService;
@Autowired
protected ChatManageService chatManageService;
@Autowired
protected AgentService agentService;
@Autowired
protected SysParameterService sysParameterService;
@Value("${s2.demo.names:S2VisitsDemo}")
protected List<String> demoList;
@Value("${s2.demo.enableLLM:true}")
protected boolean demoEnableLlm;
public void run(String... args) {
demoDatabaseResp = addDatabase();
if (demoList != null && demoList.contains(getClass().getSimpleName())) {
doRun();
}
}
abstract void doRun();
protected DatabaseResp addDatabase() {
String url = dataSourceProperties.getUrl();
DatabaseReq databaseReq = new DatabaseReq();
databaseReq.setName("数据实例");
databaseReq.setDescription("样例数据库实例");
if (StringUtils.isNotBlank(url)
&& url.toLowerCase().contains(DataType.MYSQL.getFeature().toLowerCase())) {
databaseReq.setType(DataType.MYSQL.getFeature());
databaseReq.setVersion("5.7");
} else {
databaseReq.setType(DataType.H2.getFeature());
}
databaseReq.setUrl(url);
databaseReq.setUsername(dataSourceProperties.getUsername());
databaseReq.setPassword(dataSourceProperties.getPassword());
return databaseService.createOrUpdateDatabase(databaseReq, user);
}
protected MetricResp getMetric(String bizName, ModelResp model) {
return metricService.getMetric(model.getId(), bizName);
}
protected List<DataSetModelConfig> getDataSetModelConfigs(Long domainId) {
List<DataSetModelConfig> dataSetModelConfigs = Lists.newArrayList();
List<ModelResp> modelByDomainIds =
modelService.getModelByDomainIds(Lists.newArrayList(domainId));
for (ModelResp modelResp : modelByDomainIds) {
DataSetModelConfig dataSetModelConfig = new DataSetModelConfig();
dataSetModelConfig.setId(modelResp.getId());
MetaFilter metaFilter = new MetaFilter();
metaFilter.setModelIds(Lists.newArrayList(modelResp.getId()));
List<Long> metrics = metricService.getMetrics(metaFilter)
.stream().map(MetricResp::getId).collect(Collectors.toList());
dataSetModelConfig.setMetrics(metrics);
List<Long> dimensions = dimensionService.getDimensions(metaFilter)
.stream().map(DimensionResp::getId).collect(Collectors.toList());
dataSetModelConfig.setMetrics(metrics);
dataSetModelConfig.setDimensions(dimensions);
dataSetModelConfigs.add(dataSetModelConfig);
}
return dataSetModelConfigs;
}
protected void parseAndExecute(int chatId, int agentId, String queryText) throws Exception {
ChatParseReq chatParseReq = new ChatParseReq();
chatParseReq.setQueryText(queryText);
chatParseReq.setChatId(chatId);
chatParseReq.setAgentId(agentId);
chatParseReq.setUser(User.getFakeUser());
ParseResp parseResp = chatService.performParsing(chatParseReq);
if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) {
log.info("parseResp.getSelectedParses() is empty");
return;
}
ChatExecuteReq executeReq = new ChatExecuteReq();
executeReq.setQueryId(parseResp.getQueryId());
executeReq.setParseId(parseResp.getSelectedParses().get(0).getId());
executeReq.setQueryText(queryText);
executeReq.setChatId(parseResp.getChatId());
executeReq.setUser(User.getFakeUser());
executeReq.setSaveAnswer(true);
chatService.performExecution(executeReq);
}
protected void addTag(Long itemId, TagDefineType tagDefineType) {
TagReq tagReq = new TagReq();
tagReq.setTagDefineType(tagDefineType);
tagReq.setItemId(itemId);
tagMetaService.create(tagReq, User.getFakeUser());
}
protected DimensionResp getDimension(String bizName, ModelResp model) {
return dimensionService.getDimension(bizName, model.getId());
}
protected RelateDimension getRelateDimension(List<Long> dimensionIds) {
RelateDimension relateDimension = new RelateDimension();
for (Long id : dimensionIds) {
relateDimension.getDrillDownDimensions().add(new DrillDownDimension(id));
}
return relateDimension;
}
protected void updateQueryScore(Integer queryId) {
chatManageService.updateFeedback(queryId, 5, "");
}
}

View File

@@ -1,196 +1,130 @@
package com.tencent.supersonic;
package com.tencent.supersonic.demo;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup;
import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule;
import com.tencent.supersonic.auth.api.authorization.service.AuthService;
import com.tencent.supersonic.chat.server.agent.*;
import com.tencent.supersonic.chat.server.plugin.Plugin;
import com.tencent.supersonic.chat.server.plugin.PluginParseConfig;
import com.tencent.supersonic.chat.server.plugin.build.ParamOption;
import com.tencent.supersonic.chat.server.plugin.build.WebBase;
import com.tencent.supersonic.chat.server.service.PluginService;
import com.tencent.supersonic.common.pojo.JoinCondition;
import com.tencent.supersonic.common.pojo.ModelRela;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.SensitiveLevelEnum;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.common.pojo.SysParameter;
import com.tencent.supersonic.common.pojo.enums.*;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.DataSetDetail;
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
import com.tencent.supersonic.headless.api.pojo.DefaultDisplayInfo;
import com.tencent.supersonic.headless.api.pojo.Dim;
import com.tencent.supersonic.headless.api.pojo.DimensionTimeTypeParams;
import com.tencent.supersonic.headless.api.pojo.DrillDownDimension;
import com.tencent.supersonic.headless.api.pojo.Field;
import com.tencent.supersonic.headless.api.pojo.FieldParam;
import com.tencent.supersonic.headless.api.pojo.Identify;
import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.pojo.MeasureParam;
import com.tencent.supersonic.headless.api.pojo.MetricDefineByFieldParams;
import com.tencent.supersonic.headless.api.pojo.MetricDefineByMeasureParams;
import com.tencent.supersonic.headless.api.pojo.MetricDefineByMetricParams;
import com.tencent.supersonic.headless.api.pojo.MetricParam;
import com.tencent.supersonic.headless.api.pojo.MetricTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.RelateDimension;
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.enums.DataType;
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
import com.tencent.supersonic.headless.api.pojo.enums.MetricDefineType;
import com.tencent.supersonic.headless.api.pojo.enums.SemanticType;
import com.tencent.supersonic.headless.api.pojo.enums.TagDefineType;
import com.tencent.supersonic.headless.api.pojo.request.DataSetReq;
import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq;
import com.tencent.supersonic.headless.api.pojo.request.DimensionReq;
import com.tencent.supersonic.headless.api.pojo.request.DomainReq;
import com.tencent.supersonic.headless.api.pojo.request.MetricReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
import com.tencent.supersonic.headless.api.pojo.request.TagObjectReq;
import com.tencent.supersonic.headless.api.pojo.request.TagReq;
import com.tencent.supersonic.headless.api.pojo.request.TermReq;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.TagObjectResp;
import com.tencent.supersonic.headless.server.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.service.DataSetService;
import com.tencent.supersonic.headless.server.service.DatabaseService;
import com.tencent.supersonic.headless.server.service.DimensionService;
import com.tencent.supersonic.headless.server.service.DomainService;
import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.ModelRelaService;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.service.TagMetaService;
import com.tencent.supersonic.headless.server.service.TagObjectService;
import com.tencent.supersonic.headless.server.service.TermService;
import com.tencent.supersonic.headless.api.pojo.*;
import com.tencent.supersonic.headless.api.pojo.enums.*;
import com.tencent.supersonic.headless.api.pojo.request.*;
import com.tencent.supersonic.headless.api.pojo.response.*;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
@Component
@Slf4j
public class ModelDemoDataLoader {
protected DatabaseResp tmpDatabaseResp = null;
private User user = User.getFakeUser();
@Autowired
private DatabaseService databaseService;
@Autowired
private DomainService domainService;
@Autowired
private ModelService modelService;
@Autowired
private ModelRelaService modelRelaService;
@Autowired
private DimensionService dimensionService;
@Autowired
private MetricService metricService;
@Autowired
private AuthService authService;
@Autowired
private DataSetService dataSetService;
@Autowired
private DataSourceProperties dataSourceProperties;
@Autowired
private TagObjectService tagObjectService;
@Autowired
private TagMetaService tagMetaService;
@Autowired
private TermService termService;
@Autowired
private PluginService pluginService;
@Order(1)
public class S2VisitsDemo extends S2BaseDemo {
public void doRun() {
try {
DatabaseResp databaseResp = addDatabase();
tmpDatabaseResp = databaseResp;
// create domain
DomainResp s2Domain = addDomain();
TagObjectResp s2TagObject = addTagObjectUser(s2Domain);
ModelResp userModel = addModel_1(s2Domain, databaseResp, s2TagObject);
ModelResp pvUvModel = addModel_2(s2Domain, databaseResp);
DimensionResp userDimension = getDimension("user_name", userModel);
// create models
ModelResp userModel = addModel_1(s2Domain, demoDatabaseResp, s2TagObject);
ModelResp pvUvModel = addModel_2(s2Domain, demoDatabaseResp);
ModelResp stayTimeModel = addModel_3(s2Domain, demoDatabaseResp);
addModelRela_1(s2Domain, userModel, pvUvModel);
addModelRela_2(s2Domain, userModel, stayTimeModel);
addTags(userModel);
//create metrics and dimensions
DimensionResp departmentDimension = getDimension("department", userModel);
MetricResp metricUv = addMetric_uv(userModel, departmentDimension);
MetricResp metricPv = getMetric("pv", pvUvModel);
addMetric_pv_avg(metricPv, metricUv, departmentDimension, pvUvModel);
ModelResp stayTimeModel = addModel_3(s2Domain, databaseResp);
addModelRela_1(s2Domain, userModel, pvUvModel);
addModelRela_2(s2Domain, userModel, stayTimeModel);
DomainResp singerDomain = addDomain_2();
TagObjectResp singerTagObject = addTagObjectSinger(singerDomain);
ModelResp singerModel = addModel_4(singerDomain, databaseResp, singerTagObject);
DimensionResp pageDimension = getDimension("page", stayTimeModel);
updateDimension(stayTimeModel, pageDimension);
DimensionResp userDimension = getDimension("user_name", userModel);
updateMetric(stayTimeModel, departmentDimension, userDimension);
addTags(userModel, singerModel);
updateMetric_pv(pvUvModel, departmentDimension, userDimension, metricPv);
DataSetResp s2DataSet = addDataSet_1(s2Domain);
addDataSet_2(singerDomain, singerModel);
//create data set
DataSetResp s2DataSet = addDataSet(s2Domain);
addAuthGroup_1(stayTimeModel);
addAuthGroup_2(stayTimeModel);
//create terms and plugin
addTerm(s2Domain);
addTerm_1(s2Domain);
addPlugin_1(s2DataSet, userDimension, userModel);
addPlugin(s2DataSet, userDimension, userModel);
addSysParameter();
//create agent
Integer agentId = addAgent(s2DataSet.getId());
addSampleChats(agentId);
updateQueryScore(1);
updateQueryScore(4);
} catch (Exception e) {
log.error("Failed to add model demo data", e);
log.error("Failed to add S2Visits demo data", e);
}
}
public void addSampleChats(Integer agentId) throws Exception {
Long chatId = chatManageService.addChat(user, "样例对话1", agentId);
parseAndExecute(chatId.intValue(), agentId, "超音数 访问次数");
parseAndExecute(chatId.intValue(), agentId, "按部门统计");
parseAndExecute(chatId.intValue(), agentId, "查询近30天");
parseAndExecute(chatId.intValue(), agentId, "alice 停留时长");
parseAndExecute(chatId.intValue(), agentId, "对比alice和lucy的访问次数");
parseAndExecute(chatId.intValue(), agentId, "访问次数最高的部门");
}
private TagObjectResp addTagObjectUser(DomainResp s2Domain) throws Exception {
TagObjectReq tagObjectReq = new TagObjectReq();
tagObjectReq.setDomainId(s2Domain.getId());
tagObjectReq.setName("用户");
tagObjectReq.setBizName("user");
User user = User.getFakeUser();
return tagObjectService.create(tagObjectReq, user);
public void addSysParameter() {
SysParameter sysParameter = new SysParameter();
sysParameter.setId(1);
sysParameter.init();
sysParameterService.save(sysParameter);
}
private TagObjectResp addTagObjectSinger(DomainResp singerDomain) throws Exception {
TagObjectReq tagObjectReq = new TagObjectReq();
tagObjectReq.setDomainId(singerDomain.getId());
tagObjectReq.setName("艺人");
tagObjectReq.setBizName("singer");
User user = User.getFakeUser();
return tagObjectService.create(tagObjectReq, user);
private Integer addAgent(long dataSetId) {
Agent agent = new Agent();
agent.setName("算指标");
agent.setDescription("帮助您用自然语言查询指标,支持时间限定、条件筛选、下钻维度以及聚合统计");
agent.setStatus(1);
agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList("超音数访问次数", "近15天超音数访问次数汇总", "按部门统计超音数的访问人数",
"对比alice和lucy的停留时长", "超音数访问次数最高的部门"));
AgentConfig agentConfig = new AgentConfig();
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setId("0");
ruleQueryTool.setDataSetIds(Lists.newArrayList(dataSetId));
agentConfig.getTools().add(ruleQueryTool);
if (demoEnableLlm) {
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setDataSetIds(Lists.newArrayList(dataSetId));
agentConfig.getTools().add(llmParserTool);
}
public DatabaseResp addDatabase() {
String url = dataSourceProperties.getUrl();
DatabaseReq databaseReq = new DatabaseReq();
databaseReq.setName("数据实例");
databaseReq.setDescription("样例数据库实例");
if (StringUtils.isNotBlank(url)
&& url.toLowerCase().contains(DataType.MYSQL.getFeature().toLowerCase())) {
databaseReq.setType(DataType.MYSQL.getFeature());
databaseReq.setVersion("5.7");
} else {
databaseReq.setType(DataType.H2.getFeature());
}
databaseReq.setUrl(url);
databaseReq.setUsername(dataSourceProperties.getUsername());
databaseReq.setPassword(dataSourceProperties.getPassword());
return databaseService.createOrUpdateDatabase(databaseReq, user);
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
MultiTurnConfig multiTurnConfig = new MultiTurnConfig(false);
agent.setMultiTurnConfig(multiTurnConfig);
int id = agentService.createAgent(agent, User.getFakeUser());
agent.setId(id);
return agent.getId();
}
public DomainResp addDomain() {
@@ -348,82 +282,9 @@ public class ModelDemoDataLoader {
modelRelaService.save(modelRelaReq, user);
}
public DomainResp addDomain_2() {
DomainReq domainReq = new DomainReq();
domainReq.setName("艺人库");
domainReq.setBizName("supersonic");
domainReq.setParentId(0L);
domainReq.setStatus(StatusEnum.ONLINE.getCode());
domainReq.setViewers(Arrays.asList("admin", "tom", "jack"));
domainReq.setViewOrgs(Collections.singletonList("1"));
domainReq.setAdmins(Arrays.asList("admin", "alice"));
domainReq.setAdminOrgs(Collections.emptyList());
return domainService.createDomain(domainReq, user);
}
public ModelResp addModel_4(DomainResp singerDomain,
DatabaseResp s2Database, TagObjectResp singerTagObject) throws Exception {
ModelReq modelReq = new ModelReq();
modelReq.setName("艺人库");
modelReq.setBizName("singer");
modelReq.setDescription("艺人库");
modelReq.setDatabaseId(s2Database.getId());
modelReq.setDomainId(singerDomain.getId());
modelReq.setTagObjectId(singerTagObject.getId());
modelReq.setViewers(Arrays.asList("admin", "tom", "jack"));
modelReq.setViewOrgs(Collections.singletonList("1"));
modelReq.setAdmins(Collections.singletonList("admin"));
modelReq.setAdminOrgs(Collections.emptyList());
ModelDetail modelDetail = new ModelDetail();
List<Identify> identifiers = new ArrayList<>();
Identify identify = new Identify("歌手名", IdentifyType.primary.name(), "singer_name", 1);
identify.setEntityNames(Lists.newArrayList("歌手", "艺人"));
identifiers.add(identify);
modelDetail.setIdentifiers(identifiers);
List<Dim> dimensions = new ArrayList<>();
Dim dimension1 = new Dim("", "imp_date", DimensionType.time.name(), 0);
dimension1.setTypeParams(new DimensionTimeTypeParams());
dimensions.add(dimension1);
dimensions.add(new Dim("活跃区域", "act_area",
DimensionType.categorical.name(), 1, 1));
dimensions.add(new Dim("代表作", "song_name",
DimensionType.categorical.name(), 1));
dimensions.add(new Dim("风格", "genre",
DimensionType.categorical.name(), 1, 1));
modelDetail.setDimensions(dimensions);
Measure measure1 = new Measure("播放量", "js_play_cnt", "sum", 1);
Measure measure2 = new Measure("下载量", "down_cnt", "sum", 1);
Measure measure3 = new Measure("收藏量", "favor_cnt", "sum", 1);
modelDetail.setMeasures(Lists.newArrayList(measure1, measure2, measure3));
modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("select imp_date, singer_name, act_area, song_name, genre, "
+ "js_play_cnt, down_cnt, favor_cnt from singer");
modelReq.setModelDetail(modelDetail);
return modelService.createModel(modelReq, user);
}
private void addTags(ModelResp userModel, ModelResp singerModel) {
addTag(dimensionService.getDimension("department", userModel.getId()).getId(),
private void addTags(ModelResp model) {
addTag(dimensionService.getDimension("department", model.getId()).getId(),
TagDefineType.DIMENSION);
addTag(dimensionService.getDimension("act_area", singerModel.getId()).getId(),
TagDefineType.DIMENSION);
addTag(dimensionService.getDimension("song_name", singerModel.getId()).getId(),
TagDefineType.DIMENSION);
addTag(dimensionService.getDimension("genre", singerModel.getId()).getId(),
TagDefineType.DIMENSION);
addTag(dimensionService.getDimension("singer_name", singerModel.getId()).getId(),
TagDefineType.DIMENSION);
addTag(metricService.getMetric(singerModel.getId(), "js_play_cnt").getId(),
TagDefineType.METRIC);
}
private void addTag(Long itemId, TagDefineType tagDefineType) {
TagReq tagReq = new TagReq();
tagReq.setTagDefineType(tagDefineType);
tagReq.setItemId(itemId);
tagMetaService.create(tagReq, User.getFakeUser());
}
public void updateDimension(ModelResp stayTimeModel, DimensionResp pageDimension) throws Exception {
@@ -535,7 +396,7 @@ public class ModelDemoDataLoader {
return metricService.createMetric(metricReq, user);
}
public DataSetResp addDataSet_1(DomainResp s2Domain) {
public DataSetResp addDataSet(DomainResp s2Domain) {
DataSetReq dataSetReq = new DataSetReq();
dataSetReq.setName("超音数");
dataSetReq.setBizName("s2");
@@ -558,39 +419,6 @@ public class ModelDemoDataLoader {
return dataSetService.save(dataSetReq, User.getFakeUser());
}
public void addDataSet_2(DomainResp singerDomain, ModelResp singerModel) {
DataSetReq dataSetReq = new DataSetReq();
dataSetReq.setName("艺人库");
dataSetReq.setBizName("singer");
dataSetReq.setDomainId(singerDomain.getId());
dataSetReq.setDescription("包含艺人相关标签和指标信息");
dataSetReq.setAdmins(Lists.newArrayList("admin", "jack"));
List<DataSetModelConfig> dataSetModelConfigs = getDataSetModelConfigs(singerDomain.getId());
DataSetDetail dataSetDetail = new DataSetDetail();
dataSetDetail.setDataSetModelConfigs(dataSetModelConfigs);
dataSetReq.setDataSetDetail(dataSetDetail);
dataSetReq.setTypeEnum(TypeEnums.DATASET);
QueryConfig queryConfig = new QueryConfig();
TagTypeDefaultConfig tagTypeDefaultConfig = new TagTypeDefaultConfig();
TimeDefaultConfig tagTimeDefaultConfig = new TimeDefaultConfig();
tagTimeDefaultConfig.setTimeMode(TimeMode.LAST);
tagTimeDefaultConfig.setUnit(7);
tagTypeDefaultConfig.setTimeDefaultConfig(tagTimeDefaultConfig);
DefaultDisplayInfo defaultDisplayInfo = new DefaultDisplayInfo();
defaultDisplayInfo.setDimensionIds(dataSetModelConfigs.get(0).getDimensions());
MetricResp jsPlayCntMetric = getMetric("js_play_cnt", singerModel);
defaultDisplayInfo.setMetricIds(Lists.newArrayList(jsPlayCntMetric.getId()));
tagTypeDefaultConfig.setDefaultDisplayInfo(defaultDisplayInfo);
MetricTypeDefaultConfig metricTypeDefaultConfig = new MetricTypeDefaultConfig();
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
timeDefaultConfig.setTimeMode(TimeMode.RECENT);
timeDefaultConfig.setUnit(7);
metricTypeDefaultConfig.setTimeDefaultConfig(timeDefaultConfig);
queryConfig.setTagTypeDefaultConfig(tagTypeDefaultConfig);
queryConfig.setMetricTypeDefaultConfig(metricTypeDefaultConfig);
dataSetReq.setQueryConfig(queryConfig);
dataSetService.save(dataSetReq, User.getFakeUser());
}
public void addTerm(DomainResp s2Domain) {
TermReq termReq = new TermReq();
@@ -641,7 +469,7 @@ public class ModelDemoDataLoader {
authService.addOrUpdateAuthGroup(authGroupReq);
}
private void addPlugin_1(DataSetResp s2DataSet, DimensionResp userDimension,
private void addPlugin(DataSetResp s2DataSet, DimensionResp userDimension,
ModelResp userModel) {
Plugin plugin1 = new Plugin();
plugin1.setType("WEB_PAGE");
@@ -666,42 +494,13 @@ public class ModelDemoDataLoader {
pluginService.createPlugin(plugin1, user);
}
private RelateDimension getRelateDimension(List<Long> dimensionIds) {
RelateDimension relateDimension = new RelateDimension();
for (Long id : dimensionIds) {
relateDimension.getDrillDownDimensions().add(new DrillDownDimension(id));
}
return relateDimension;
}
private DimensionResp getDimension(String bizName, ModelResp model) {
return dimensionService.getDimension(bizName, model.getId());
}
private MetricResp getMetric(String bizName, ModelResp model) {
return metricService.getMetric(model.getId(), bizName);
}
protected List<DataSetModelConfig> getDataSetModelConfigs(Long domainId) {
List<DataSetModelConfig> dataSetModelConfigs = Lists.newArrayList();
List<ModelResp> modelByDomainIds =
modelService.getModelByDomainIds(Lists.newArrayList(domainId));
for (ModelResp modelResp : modelByDomainIds) {
DataSetModelConfig dataSetModelConfig = new DataSetModelConfig();
dataSetModelConfig.setId(modelResp.getId());
MetaFilter metaFilter = new MetaFilter();
metaFilter.setModelIds(Lists.newArrayList(modelResp.getId()));
List<Long> metrics = metricService.getMetrics(metaFilter)
.stream().map(MetricResp::getId).collect(Collectors.toList());
dataSetModelConfig.setMetrics(metrics);
List<Long> dimensions = dimensionService.getDimensions(metaFilter)
.stream().map(DimensionResp::getId).collect(Collectors.toList());
dataSetModelConfig.setMetrics(metrics);
dataSetModelConfig.setDimensions(dimensions);
dataSetModelConfigs.add(dataSetModelConfig);
}
return dataSetModelConfigs;
private TagObjectResp addTagObjectUser(DomainResp s2Domain) throws Exception {
TagObjectReq tagObjectReq = new TagObjectReq();
tagObjectReq.setDomainId(s2Domain.getId());
tagObjectReq.setName("用户");
tagObjectReq.setBizName("user");
User user = User.getFakeUser();
return tagObjectService.create(tagObjectReq, user);
}
}

View File

@@ -9,7 +9,6 @@ spring:
h2:
console:
path: /h2-console/semantic
# enabled web
enabled: true
datasource:
driver-class-name: org.h2.Driver
@@ -31,16 +30,10 @@ authentication:
header:
key: Authorization
demo:
enabled: true
query:
optimizer:
enable: true
multi:
turn: false
time:
threshold: 100
@@ -70,6 +63,11 @@ text2sql:
num: 1
s2:
demo:
names: S2VisitsDemo,S2ArtistDemo
enableLLM: true
multi-turn:
enable: false
langchain4j:
#1.chat-model
chat-model:

View File

@@ -1,8 +1,14 @@
server:
port: 9080
compression:
enabled: true
min-response-size: 1024
mime-types: application/javascript,application/json,application/xml,text/html,text/xml,text/plain,text/css,image/*
spring:
h2:
console:
path: /h2-console/chat
# enabled web
path: /h2-console/semantic
enabled: true
datasource:
driver-class-name: org.h2.Driver
@@ -12,32 +18,56 @@ spring:
username: root
password: semantic
demo:
enabled: true
nl2SqlLlm:
enabled: false
server:
port: 9080
mybatis:
mapper-locations=classpath:mappers/custom/*.xml,classpath*:/mappers/*.xml
authentication:
enable: false
enable: true
exclude:
path: /api/auth/user/register,/api/auth/user/login
token:
http:
header:
key: Authorization
semantic:
url:
prefix: http://127.0.0.1:9081
query:
optimizer:
enable: true
time:
threshold: 100
mybatis:
mapper-locations=classpath:mappers/custom/*.xml,classpath*:/mappers/*.xml
dimension:
topn: 20
metric:
topn: 20
corrector:
additional:
information: true
pyllm:
url: http://127.0.0.1:9092
llm:
parser:
url: ${pyllm.url}
embedding:
url: ${pyllm.url}
functionCall:
url: ${pyllm.url}
text2sql:
example:
num: 1
#langchain4j config
s2:
demo:
names: S2VisitsDemo,S2ArtistDemo
enableLLM: true
multi-turn:
enable: false
langchain4j:
#1.chat-model
chat-model:
@@ -58,6 +88,7 @@ s2:
# inProcess:
# modelPath: /data/model.onnx
# vocabularyPath: /data/onnx_vocab.txt
# shibing624/text2vec-base-chinese
#2.2 open_ai
# embedding-model:
# provider: open_ai