From 5d16aa0ab47934e62e3cb74127cf56d79bed54cc Mon Sep 17 00:00:00 2001 From: jerryjzhang Date: Sat, 25 May 2024 00:26:55 +0800 Subject: [PATCH] (improvement)(launcher)Refactor Demo impl and config --- .../chat/server/parser/MultiTurnParser.java | 12 +- .../tencent/supersonic/ChatDemoLoader.java | 231 ----------- .../supersonic/HeadlessDemoLoader.java | 53 --- .../CspiderDemo.java} | 61 +-- .../DuSQLDemo.java} | 76 ++-- .../tencent/supersonic/demo/S2ArtistDemo.java | 197 +++++++++ .../tencent/supersonic/demo/S2BaseDemo.java | 176 ++++++++ .../S2VisitsDemo.java} | 381 +++++------------- .../src/main/resources/application-local.yaml | 12 +- .../src/test/resources/application-local.yaml | 63 ++- 10 files changed, 584 insertions(+), 678 deletions(-) delete mode 100644 launchers/standalone/src/main/java/com/tencent/supersonic/ChatDemoLoader.java delete mode 100644 launchers/standalone/src/main/java/com/tencent/supersonic/HeadlessDemoLoader.java rename launchers/standalone/src/main/java/com/tencent/supersonic/{BenchMarkDemoDataLoader.java => demo/CspiderDemo.java} (88%) rename launchers/standalone/src/main/java/com/tencent/supersonic/{DuSQLDemoDataLoader.java => demo/DuSQLDemo.java} (88%) create mode 100644 launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2ArtistDemo.java create mode 100644 launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2BaseDemo.java rename launchers/standalone/src/main/java/com/tencent/supersonic/{ModelDemoDataLoader.java => demo/S2VisitsDemo.java} (58%) diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java index 378c68220..be90e7e37 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/MultiTurnParser.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.chat.server.parser; +import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; import com.tencent.supersonic.chat.server.persistence.repository.ChatQueryRepository; import com.tencent.supersonic.chat.server.pojo.ChatParseContext; import com.tencent.supersonic.chat.server.util.QueryReqConverter; @@ -34,14 +35,14 @@ import java.util.Collections; @Slf4j public class MultiTurnParser implements ChatParser { - private static final Logger keyPipelineLog = LoggerFactory.getLogger(MultiTurnParser.class); + private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); private static final PromptTemplate promptTemplate = PromptTemplate.from( "You are a data product manager experienced in data requirements." + "Your will be provided with current and history questions asked by a user," + "along with their mapped schema elements(metric, dimension and value), " + "please try understanding the semantics and rewrite a question" - + "(keep relevant metrics, dimensions, values and date ranges)." + + "(keep relevant entities, metrics, dimensions, values and date ranges)." + "Current Question: {{curtQuestion}} " + "Current Mapped Schema: {{curtSchema}} " + "History Question: {{histQuestion}} " @@ -51,8 +52,11 @@ public class MultiTurnParser implements ChatParser { @Override public void parse(ChatParseContext chatParseContext, ParseResp parseResp) { Environment environment = ContextUtils.getBean(Environment.class); - Boolean multiTurn = environment.getProperty("multi.turn", Boolean.class); - if (!Boolean.TRUE.equals(multiTurn)) { + MultiTurnConfig agentMultiTurnConfig = chatParseContext.getAgent().getMultiTurnConfig(); + Boolean globalMultiTurnConfig = environment.getProperty("s2.multi-turn.enable", Boolean.class); + + Boolean multiTurnConfig = agentMultiTurnConfig != null ? agentMultiTurnConfig.isEnableMultiTurn() : globalMultiTurnConfig; + if (!Boolean.TRUE.equals(multiTurnConfig)) { return; } diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/ChatDemoLoader.java b/launchers/standalone/src/main/java/com/tencent/supersonic/ChatDemoLoader.java deleted file mode 100644 index c5406a98f..000000000 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/ChatDemoLoader.java +++ /dev/null @@ -1,231 +0,0 @@ -package com.tencent.supersonic; - -import com.alibaba.fastjson.JSONObject; -import com.google.common.collect.Lists; -import com.tencent.supersonic.auth.api.authentication.pojo.User; -import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq; -import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; -import com.tencent.supersonic.chat.server.agent.Agent; -import com.tencent.supersonic.chat.server.agent.AgentConfig; -import com.tencent.supersonic.chat.server.agent.AgentToolType; -import com.tencent.supersonic.chat.server.agent.LLMParserTool; -import com.tencent.supersonic.chat.server.agent.MultiTurnConfig; -import com.tencent.supersonic.chat.server.agent.RuleParserTool; -import com.tencent.supersonic.chat.server.service.AgentService; -import com.tencent.supersonic.chat.server.service.ChatManageService; -import com.tencent.supersonic.chat.server.service.ChatService; -import com.tencent.supersonic.common.pojo.SysParameter; -import com.tencent.supersonic.common.service.SysParameterService; -import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.headless.api.pojo.response.ParseResp; -import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.boot.CommandLineRunner; -import org.springframework.core.annotation.Order; -import org.springframework.stereotype.Component; -import org.springframework.util.CollectionUtils; - -@Component -@Slf4j -@Order(3) -public class ChatDemoLoader implements CommandLineRunner { - - private User user = User.getFakeUser(); - @Autowired - private ChatService chatService; - @Autowired - private ChatManageService chatManageService; - @Autowired - private AgentService agentService; - @Autowired - private SysParameterService sysParameterService; - - @Value("${demo.enabled:false}") - private boolean demoEnabled; - - @Value("${demo.nl2SqlLlm.enabled:true}") - private boolean demoEnabledNl2SqlLlm; - - @Override - public void run(String... args) throws Exception { - if (!checkEnable()) { - log.info("skip load chat demo"); - return; - } - doRun(); - } - - public void doRun() { - try { - addSysParameter(); - Integer agentId = addAgent1(); - addAgent2(); - addAgent3(); - //addAgent4(); - addSampleChats(agentId); - addSampleChats2(agentId); - updateQueryScore(1); - updateQueryScore(4); - } catch (Exception e) { - log.error("Failed to add sample chats", e); - } - } - - private void parseAndExecute(int chatId, int agentId, String queryText) throws Exception { - ChatParseReq chatParseReq = new ChatParseReq(); - chatParseReq.setQueryText(queryText); - chatParseReq.setChatId(chatId); - chatParseReq.setAgentId(agentId); - chatParseReq.setUser(User.getFakeUser()); - ParseResp parseResp = chatService.performParsing(chatParseReq); - if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) { - log.info("parseResp.getSelectedParses() is empty"); - return; - } - ChatExecuteReq executeReq = new ChatExecuteReq(); - executeReq.setQueryId(parseResp.getQueryId()); - executeReq.setParseId(parseResp.getSelectedParses().get(0).getId()); - executeReq.setQueryText(queryText); - executeReq.setChatId(parseResp.getChatId()); - executeReq.setUser(User.getFakeUser()); - executeReq.setSaveAnswer(true); - chatService.performExecution(executeReq); - } - - public void addSampleChats(Integer agentId) throws Exception { - Long chatId = chatManageService.addChat(user, "样例对话1", agentId); - - parseAndExecute(chatId.intValue(), agentId, "超音数 访问次数"); - parseAndExecute(chatId.intValue(), agentId, "按部门统计"); - parseAndExecute(chatId.intValue(), agentId, "查询近30天"); - } - - public void addSampleChats2(Integer agentId) throws Exception { - Long chatId = chatManageService.addChat(user, "样例对话2", agentId); - - parseAndExecute(chatId.intValue(), agentId, "alice 停留时长"); - parseAndExecute(chatId.intValue(), agentId, "对比alice和lucy的访问次数"); - parseAndExecute(chatId.intValue(), agentId, "访问次数最高的部门"); - } - - public void addSysParameter() { - SysParameter sysParameter = new SysParameter(); - sysParameter.setId(1); - sysParameter.init(); - sysParameterService.save(sysParameter); - } - - private Integer addAgent1() { - Agent agent = new Agent(); - agent.setId(1); - agent.setName("算指标"); - agent.setDescription("帮助您用自然语言查询指标,支持时间限定、条件筛选、下钻维度以及聚合统计"); - agent.setStatus(1); - agent.setEnableSearch(1); - agent.setExamples(Lists.newArrayList("超音数访问次数", "近15天超音数访问次数汇总", "按部门统计超音数的访问人数", - "对比alice和lucy的停留时长", "超音数访问次数最高的部门")); - AgentConfig agentConfig = new AgentConfig(); - RuleParserTool ruleQueryTool = new RuleParserTool(); - ruleQueryTool.setType(AgentToolType.NL2SQL_RULE); - ruleQueryTool.setId("0"); - ruleQueryTool.setDataSetIds(Lists.newArrayList(-1L)); - agentConfig.getTools().add(ruleQueryTool); - if (demoEnabledNl2SqlLlm) { - LLMParserTool llmParserTool = new LLMParserTool(); - llmParserTool.setId("1"); - llmParserTool.setType(AgentToolType.NL2SQL_LLM); - llmParserTool.setDataSetIds(Lists.newArrayList(-1L)); - agentConfig.getTools().add(llmParserTool); - } - agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); - MultiTurnConfig multiTurnConfig = new MultiTurnConfig(false); - agent.setMultiTurnConfig(multiTurnConfig); - agentService.createAgent(agent, User.getFakeUser()); - return agent.getId(); - } - - private void addAgent2() { - Agent agent = new Agent(); - agent.setId(2); - agent.setName("标签圈选"); - agent.setDescription("帮助您用自然语言进行圈选,支持多条件组合筛选"); - agent.setStatus(1); - agent.setEnableSearch(1); - agent.setExamples(Lists.newArrayList("国风风格艺人", "港台地区的艺人", "风格为流行的艺人")); - AgentConfig agentConfig = new AgentConfig(); - RuleParserTool ruleQueryTool = new RuleParserTool(); - ruleQueryTool.setId("0"); - ruleQueryTool.setType(AgentToolType.NL2SQL_RULE); - ruleQueryTool.setDataSetIds(Lists.newArrayList(-1L)); - agentConfig.getTools().add(ruleQueryTool); - - if (demoEnabledNl2SqlLlm) { - LLMParserTool llmParserTool = new LLMParserTool(); - llmParserTool.setId("1"); - llmParserTool.setType(AgentToolType.NL2SQL_LLM); - llmParserTool.setDataSetIds(Lists.newArrayList(-1L)); - agentConfig.getTools().add(llmParserTool); - } - agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); - agentService.createAgent(agent, User.getFakeUser()); - } - - private void addAgent3() { - Agent agent = new Agent(); - agent.setId(3); - agent.setName("cspider"); - agent.setDescription("cspider数据集的case展示"); - agent.setStatus(1); - agent.setEnableSearch(1); - agent.setExamples(Lists.newArrayList("可用“mp4”格式且分辨率低于1000的歌曲的ID是什么?", - "“孟加拉语”歌曲的平均评分和分辨率是多少?", - "找出所有至少有一首“英文”歌曲的艺术家的名字和作品数量。")); - AgentConfig agentConfig = new AgentConfig(); - if (demoEnabledNl2SqlLlm) { - LLMParserTool llmParserTool = new LLMParserTool(); - llmParserTool.setId("1"); - llmParserTool.setType(AgentToolType.NL2SQL_LLM); - llmParserTool.setDataSetIds(Lists.newArrayList(3L)); - agentConfig.getTools().add(llmParserTool); - } - - agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); - agentService.createAgent(agent, User.getFakeUser()); - } - - private void addAgent4() { - Agent agent = new Agent(); - agent.setId(4); - agent.setName("DuSQL 互联网企业"); - agent.setDescription("DuSQL"); - agent.setStatus(1); - agent.setEnableSearch(1); - agent.setExamples(Lists.newArrayList()); - AgentConfig agentConfig = new AgentConfig(); - - if (demoEnabledNl2SqlLlm) { - LLMParserTool llmParserTool = new LLMParserTool(); - llmParserTool.setId("1"); - llmParserTool.setType(AgentToolType.NL2SQL_LLM); - llmParserTool.setDataSetIds(Lists.newArrayList(4L)); - agentConfig.getTools().add(llmParserTool); - } - - agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); - log.info("agent:{}", JsonUtil.toString(agent)); - agentService.createAgent(agent, User.getFakeUser()); - } - - private void updateQueryScore(Integer queryId) { - chatManageService.updateFeedback(queryId, 5, ""); - } - - private boolean checkEnable() { - if (!demoEnabled) { - return false; - } - return HeadlessDemoLoader.isLoad(); - } - -} diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/HeadlessDemoLoader.java b/launchers/standalone/src/main/java/com/tencent/supersonic/HeadlessDemoLoader.java deleted file mode 100644 index 1cde40f9c..000000000 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/HeadlessDemoLoader.java +++ /dev/null @@ -1,53 +0,0 @@ -package com.tencent.supersonic; - -import com.tencent.supersonic.headless.server.service.DomainService; -import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.boot.CommandLineRunner; -import org.springframework.core.annotation.Order; -import org.springframework.stereotype.Component; -import org.springframework.util.CollectionUtils; - -@Component -@Slf4j -@Order(1) -public class HeadlessDemoLoader implements CommandLineRunner { - - private static boolean isLoad = false; - - @Autowired - private DomainService domainService; - - @Autowired - private ModelDemoDataLoader modelDataDemoLoader; - - @Autowired - private BenchMarkDemoDataLoader benchMarkDemoLoader; - - @Value("${demo.enabled:false}") - private boolean demoEnabled; - - @Override - public void run(String... args) { - if (!checkLoadDemo()) { - log.info("skip load demo"); - return; - } - modelDataDemoLoader.doRun(); - benchMarkDemoLoader.doRun(); - isLoad = true; - } - - private boolean checkLoadDemo() { - if (!demoEnabled) { - return false; - } - return CollectionUtils.isEmpty(domainService.getDomainList()); - } - - public static boolean isLoad() { - return isLoad; - } - -} diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/BenchMarkDemoDataLoader.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/CspiderDemo.java similarity index 88% rename from launchers/standalone/src/main/java/com/tencent/supersonic/BenchMarkDemoDataLoader.java rename to launchers/standalone/src/main/java/com/tencent/supersonic/demo/CspiderDemo.java index 72d0c0690..f6e031bd4 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/BenchMarkDemoDataLoader.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/CspiderDemo.java @@ -1,4 +1,4 @@ -package com.tencent.supersonic; +package com.tencent.supersonic.demo; import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; @@ -28,13 +28,7 @@ import com.tencent.supersonic.headless.api.pojo.request.DataSetReq; import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; import com.tencent.supersonic.headless.api.pojo.response.DomainResp; import com.tencent.supersonic.headless.api.pojo.response.ModelResp; -import com.tencent.supersonic.headless.server.service.DomainService; -import com.tencent.supersonic.headless.server.service.MetricService; -import com.tencent.supersonic.headless.server.service.ModelRelaService; -import com.tencent.supersonic.headless.server.service.ModelService; -import com.tencent.supersonic.headless.server.service.DataSetService; import lombok.extern.slf4j.Slf4j; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import java.util.ArrayList; @@ -44,31 +38,15 @@ import java.util.List; @Component @Slf4j -public class BenchMarkDemoDataLoader { - - private User user = User.getFakeUser(); - - @Autowired - private DomainService domainService; - @Autowired - private ModelService modelService; - @Autowired - private ModelRelaService modelRelaService; - @Autowired - private DataSetService viewService; - @Autowired - private MetricService metricService; - @Autowired - private ModelDemoDataLoader modelDemoDataLoader; +public class CspiderDemo extends S2BaseDemo { public void doRun() { try { - DatabaseResp databaseResp = modelDemoDataLoader.tmpDatabaseResp; DomainResp s2Domain = addDomain(); - ModelResp genreModelResp = addModel_1(s2Domain, databaseResp); - ModelResp artistModelResp = addModel_2(s2Domain, databaseResp); - ModelResp filesModelResp = addModel_3(s2Domain, databaseResp); - ModelResp songModelResp = addModel_4(s2Domain, databaseResp); + ModelResp genreModelResp = addModel_1(s2Domain, demoDatabaseResp); + ModelResp artistModelResp = addModel_2(s2Domain, demoDatabaseResp); + ModelResp filesModelResp = addModel_3(s2Domain, demoDatabaseResp); + ModelResp songModelResp = addModel_4(s2Domain, demoDatabaseResp); addDataSet_1(s2Domain); addModelRela_1(s2Domain, genreModelResp, artistModelResp); addModelRela_2(s2Domain, filesModelResp, artistModelResp); @@ -79,7 +57,6 @@ public class BenchMarkDemoDataLoader { } catch (Exception e) { log.error("Failed to add bench mark demo data", e); } - } public DomainResp addDomain() { @@ -222,17 +199,17 @@ public class BenchMarkDemoDataLoader { } public void addDataSet_1(DomainResp s2Domain) { - DataSetReq viewReq = new DataSetReq(); - viewReq.setName("cspider"); - viewReq.setBizName("singer"); - viewReq.setDomainId(s2Domain.getId()); - viewReq.setDescription("包含cspider数据集相关标签和指标信息"); - viewReq.setAdmins(Lists.newArrayList("admin")); - List viewModelConfigs = modelDemoDataLoader.getDataSetModelConfigs(s2Domain.getId()); - DataSetDetail viewDetail = new DataSetDetail(); - viewDetail.setDataSetModelConfigs(viewModelConfigs); - viewReq.setDataSetDetail(viewDetail); - viewReq.setTypeEnum(TypeEnums.DATASET); + DataSetReq dataSetReq = new DataSetReq(); + dataSetReq.setName("cspider"); + dataSetReq.setBizName("singer"); + dataSetReq.setDomainId(s2Domain.getId()); + dataSetReq.setDescription("包含cspider数据集相关标签和指标信息"); + dataSetReq.setAdmins(Lists.newArrayList("admin")); + List viewModelConfigs = getDataSetModelConfigs(s2Domain.getId()); + DataSetDetail dsDetail = new DataSetDetail(); + dsDetail.setDataSetModelConfigs(viewModelConfigs); + dataSetReq.setDataSetDetail(dsDetail); + dataSetReq.setTypeEnum(TypeEnums.DATASET); QueryConfig queryConfig = new QueryConfig(); TagTypeDefaultConfig tagTypeDefaultConfig = new TagTypeDefaultConfig(); TimeDefaultConfig tagTimeDefaultConfig = new TimeDefaultConfig(); @@ -250,8 +227,8 @@ public class BenchMarkDemoDataLoader { metricTypeDefaultConfig.setTimeDefaultConfig(timeDefaultConfig); queryConfig.setTagTypeDefaultConfig(tagTypeDefaultConfig); queryConfig.setMetricTypeDefaultConfig(metricTypeDefaultConfig); - viewReq.setQueryConfig(queryConfig); - viewService.save(viewReq, User.getFakeUser()); + dataSetReq.setQueryConfig(queryConfig); + dataSetService.save(dataSetReq, User.getFakeUser()); } public void addModelRela_1(DomainResp s2Domain, ModelResp genreModelResp, ModelResp artistModelResp) { diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/DuSQLDemoDataLoader.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/DuSQLDemo.java similarity index 88% rename from launchers/standalone/src/main/java/com/tencent/supersonic/DuSQLDemoDataLoader.java rename to launchers/standalone/src/main/java/com/tencent/supersonic/demo/DuSQLDemo.java index ed13483fa..5335737a5 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/DuSQLDemoDataLoader.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/DuSQLDemo.java @@ -1,13 +1,20 @@ -package com.tencent.supersonic; +package com.tencent.supersonic.demo; +import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.chat.server.agent.Agent; +import com.tencent.supersonic.chat.server.agent.AgentConfig; +import com.tencent.supersonic.chat.server.agent.AgentToolType; +import com.tencent.supersonic.chat.server.agent.LLMParserTool; import com.tencent.supersonic.common.pojo.JoinCondition; import com.tencent.supersonic.common.pojo.ModelRela; import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.pojo.enums.TimeMode; import com.tencent.supersonic.common.pojo.enums.TypeEnums; import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; +import com.tencent.supersonic.common.util.JsonUtil; +import com.tencent.supersonic.demo.S2BaseDemo; import com.tencent.supersonic.headless.api.pojo.enums.DimensionType; import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType; import com.tencent.supersonic.headless.api.pojo.Dim; @@ -25,14 +32,8 @@ import com.tencent.supersonic.headless.api.pojo.request.MetricReq; import com.tencent.supersonic.headless.api.pojo.request.ModelReq; import com.tencent.supersonic.headless.api.pojo.request.DataSetReq; import com.tencent.supersonic.headless.api.pojo.response.MetricResp; -import com.tencent.supersonic.headless.server.service.DomainService; -import com.tencent.supersonic.headless.server.service.MetricService; -import com.tencent.supersonic.headless.server.service.ModelRelaService; -import com.tencent.supersonic.headless.server.service.ModelService; -import com.tencent.supersonic.headless.server.service.DataSetService; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.BeanUtils; -import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Component; import java.util.ArrayList; @@ -42,21 +43,7 @@ import java.util.List; @Component @Slf4j -public class DuSQLDemoDataLoader { - - private User user = User.getFakeUser(); - - @Autowired - private DomainService domainService; - @Autowired - private ModelService modelService; - @Autowired - private ModelRelaService modelRelaService; - @Autowired - private MetricService metricService; - - @Autowired - private DataSetService viewService; +public class DuSQLDemo extends S2BaseDemo { public void doRun() { try { @@ -70,6 +57,7 @@ public class DuSQLDemoDataLoader { addModelRela_2(); addModelRela_3(); addModelRela_4(); + addAgent(); } catch (Exception e) { log.error("Failed to add bench mark demo data", e); } @@ -255,22 +243,22 @@ public class DuSQLDemoDataLoader { } public void addDataSet_1() { - DataSetReq viewReq = new DataSetReq(); - viewReq.setName("DuSQL 互联网企业"); - viewReq.setBizName("internet"); - viewReq.setDomainId(4L); - viewReq.setDescription("DuSQL互联网企业数据源相关的指标和维度等"); - viewReq.setAdmins(Lists.newArrayList("admin")); + DataSetReq dataSetReq = new DataSetReq(); + dataSetReq.setName("DuSQL 互联网企业"); + dataSetReq.setBizName("internet"); + dataSetReq.setDomainId(4L); + dataSetReq.setDescription("DuSQL互联网企业数据源相关的指标和维度等"); + dataSetReq.setAdmins(Lists.newArrayList("admin")); List viewModelConfigs = Lists.newArrayList( new DataSetModelConfig(9L, Lists.newArrayList(16L, 17L, 18L, 19L, 20L), Lists.newArrayList(10L, 11L)), new DataSetModelConfig(10L, Lists.newArrayList(21L, 22L, 23L), Lists.newArrayList(12L)), new DataSetModelConfig(11L, Lists.newArrayList(), Lists.newArrayList(13L, 14L, 15L)), new DataSetModelConfig(12L, Lists.newArrayList(24L), Lists.newArrayList(16L, 17L, 18L, 19L))); - DataSetDetail viewDetail = new DataSetDetail(); - viewDetail.setDataSetModelConfigs(viewModelConfigs); - viewReq.setDataSetDetail(viewDetail); - viewReq.setTypeEnum(TypeEnums.DATASET); + DataSetDetail dsDetail = new DataSetDetail(); + dsDetail.setDataSetModelConfigs(viewModelConfigs); + dataSetReq.setDataSetDetail(dsDetail); + dataSetReq.setTypeEnum(TypeEnums.DATASET); QueryConfig queryConfig = new QueryConfig(); MetricTypeDefaultConfig metricTypeDefaultConfig = new MetricTypeDefaultConfig(); TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig(); @@ -278,8 +266,8 @@ public class DuSQLDemoDataLoader { timeDefaultConfig.setUnit(1); metricTypeDefaultConfig.setTimeDefaultConfig(timeDefaultConfig); queryConfig.setMetricTypeDefaultConfig(metricTypeDefaultConfig); - viewReq.setQueryConfig(queryConfig); - viewService.save(viewReq, User.getFakeUser()); + dataSetReq.setQueryConfig(queryConfig); + dataSetService.save(dataSetReq, User.getFakeUser()); } public void addModelRela_1() { @@ -330,4 +318,24 @@ public class DuSQLDemoDataLoader { modelRelaService.save(modelRelaReq, user); } + private void addAgent() { + Agent agent = new Agent(); + agent.setName("DuSQL 互联网企业"); + agent.setDescription("DuSQL"); + agent.setStatus(1); + agent.setEnableSearch(1); + agent.setExamples(Lists.newArrayList()); + AgentConfig agentConfig = new AgentConfig(); + + LLMParserTool llmParserTool = new LLMParserTool(); + llmParserTool.setId("1"); + llmParserTool.setType(AgentToolType.NL2SQL_LLM); + llmParserTool.setDataSetIds(Lists.newArrayList(4L)); + agentConfig.getTools().add(llmParserTool); + + agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); + log.info("agent:{}", JsonUtil.toString(agent)); + agentService.createAgent(agent, User.getFakeUser()); + } + } diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2ArtistDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2ArtistDemo.java new file mode 100644 index 000000000..ba1f80fe4 --- /dev/null +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2ArtistDemo.java @@ -0,0 +1,197 @@ +package com.tencent.supersonic.demo; + +import com.alibaba.fastjson.JSONObject; +import com.google.common.collect.Lists; +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.chat.server.agent.*; +import com.tencent.supersonic.common.pojo.enums.StatusEnum; +import com.tencent.supersonic.common.pojo.enums.TimeMode; +import com.tencent.supersonic.common.pojo.enums.TypeEnums; +import com.tencent.supersonic.headless.api.pojo.DataSetDetail; +import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig; +import com.tencent.supersonic.headless.api.pojo.DefaultDisplayInfo; +import com.tencent.supersonic.headless.api.pojo.Dim; +import com.tencent.supersonic.headless.api.pojo.DimensionTimeTypeParams; +import com.tencent.supersonic.headless.api.pojo.Identify; +import com.tencent.supersonic.headless.api.pojo.Measure; +import com.tencent.supersonic.headless.api.pojo.MetricTypeDefaultConfig; +import com.tencent.supersonic.headless.api.pojo.ModelDetail; +import com.tencent.supersonic.headless.api.pojo.QueryConfig; +import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig; +import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig; +import com.tencent.supersonic.headless.api.pojo.enums.DimensionType; +import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType; +import com.tencent.supersonic.headless.api.pojo.enums.TagDefineType; +import com.tencent.supersonic.headless.api.pojo.request.DataSetReq; +import com.tencent.supersonic.headless.api.pojo.request.DomainReq; +import com.tencent.supersonic.headless.api.pojo.request.ModelReq; +import com.tencent.supersonic.headless.api.pojo.request.TagObjectReq; +import com.tencent.supersonic.headless.api.pojo.response.*; +import lombok.extern.slf4j.Slf4j; +import org.springframework.core.annotation.Order; +import org.springframework.stereotype.Component; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +@Component +@Slf4j +@Order(2) +public class S2ArtistDemo extends S2BaseDemo { + + public void doRun() { + try { + DomainResp singerDomain = addDomain(); + TagObjectResp singerTagObject = addTagObjectSinger(singerDomain); + ModelResp singerModel = addModel(singerDomain, demoDatabaseResp, singerTagObject); + addTags(singerModel); + long dataSetId = addDataSet(singerDomain, singerModel); + addAgent(dataSetId); + } catch (Exception e) { + log.error("Failed to add model demo data", e); + } + + } + + private TagObjectResp addTagObjectSinger(DomainResp singerDomain) throws Exception { + TagObjectReq tagObjectReq = new TagObjectReq(); + tagObjectReq.setDomainId(singerDomain.getId()); + tagObjectReq.setName("艺人"); + tagObjectReq.setBizName("singer"); + User user = User.getFakeUser(); + return tagObjectService.create(tagObjectReq, user); + } + + public DomainResp addDomain() { + DomainReq domainReq = new DomainReq(); + domainReq.setName("艺人库"); + domainReq.setBizName("supersonic"); + domainReq.setParentId(0L); + domainReq.setStatus(StatusEnum.ONLINE.getCode()); + domainReq.setViewers(Arrays.asList("admin", "tom", "jack")); + domainReq.setViewOrgs(Collections.singletonList("1")); + domainReq.setAdmins(Arrays.asList("admin", "alice")); + domainReq.setAdminOrgs(Collections.emptyList()); + return domainService.createDomain(domainReq, user); + } + + public ModelResp addModel(DomainResp singerDomain, + DatabaseResp s2Database, TagObjectResp singerTagObject) throws Exception { + ModelReq modelReq = new ModelReq(); + modelReq.setName("艺人库"); + modelReq.setBizName("singer"); + modelReq.setDescription("艺人库"); + modelReq.setDatabaseId(s2Database.getId()); + modelReq.setDomainId(singerDomain.getId()); + modelReq.setTagObjectId(singerTagObject.getId()); + modelReq.setViewers(Arrays.asList("admin", "tom", "jack")); + modelReq.setViewOrgs(Collections.singletonList("1")); + modelReq.setAdmins(Collections.singletonList("admin")); + modelReq.setAdminOrgs(Collections.emptyList()); + ModelDetail modelDetail = new ModelDetail(); + List identifiers = new ArrayList<>(); + Identify identify = new Identify("歌手名", IdentifyType.primary.name(), "singer_name", 1); + identify.setEntityNames(Lists.newArrayList("歌手", "艺人")); + identifiers.add(identify); + modelDetail.setIdentifiers(identifiers); + + List dimensions = new ArrayList<>(); + Dim dimension1 = new Dim("", "imp_date", DimensionType.time.name(), 0); + dimension1.setTypeParams(new DimensionTimeTypeParams()); + dimensions.add(dimension1); + dimensions.add(new Dim("活跃区域", "act_area", + DimensionType.categorical.name(), 1, 1)); + dimensions.add(new Dim("代表作", "song_name", + DimensionType.categorical.name(), 1)); + dimensions.add(new Dim("风格", "genre", + DimensionType.categorical.name(), 1, 1)); + modelDetail.setDimensions(dimensions); + + Measure measure1 = new Measure("播放量", "js_play_cnt", "sum", 1); + Measure measure2 = new Measure("下载量", "down_cnt", "sum", 1); + Measure measure3 = new Measure("收藏量", "favor_cnt", "sum", 1); + modelDetail.setMeasures(Lists.newArrayList(measure1, measure2, measure3)); + modelDetail.setQueryType("sql_query"); + modelDetail.setSqlQuery("select imp_date, singer_name, act_area, song_name, genre, " + + "js_play_cnt, down_cnt, favor_cnt from singer"); + modelReq.setModelDetail(modelDetail); + return modelService.createModel(modelReq, user); + } + + private void addTags(ModelResp model) { + addTag(dimensionService.getDimension("act_area", model.getId()).getId(), + TagDefineType.DIMENSION); + addTag(dimensionService.getDimension("song_name", model.getId()).getId(), + TagDefineType.DIMENSION); + addTag(dimensionService.getDimension("genre", model.getId()).getId(), + TagDefineType.DIMENSION); + addTag(dimensionService.getDimension("singer_name", model.getId()).getId(), + TagDefineType.DIMENSION); + addTag(metricService.getMetric(model.getId(), "js_play_cnt").getId(), + TagDefineType.METRIC); + } + + public long addDataSet(DomainResp singerDomain, ModelResp singerModel) { + DataSetReq dataSetReq = new DataSetReq(); + dataSetReq.setName("艺人库"); + dataSetReq.setBizName("singer"); + dataSetReq.setDomainId(singerDomain.getId()); + dataSetReq.setDescription("包含艺人相关标签和指标信息"); + dataSetReq.setAdmins(Lists.newArrayList("admin", "jack")); + List dataSetModelConfigs = getDataSetModelConfigs(singerDomain.getId()); + DataSetDetail dataSetDetail = new DataSetDetail(); + dataSetDetail.setDataSetModelConfigs(dataSetModelConfigs); + dataSetReq.setDataSetDetail(dataSetDetail); + dataSetReq.setTypeEnum(TypeEnums.DATASET); + QueryConfig queryConfig = new QueryConfig(); + TagTypeDefaultConfig tagTypeDefaultConfig = new TagTypeDefaultConfig(); + TimeDefaultConfig tagTimeDefaultConfig = new TimeDefaultConfig(); + tagTimeDefaultConfig.setTimeMode(TimeMode.LAST); + tagTimeDefaultConfig.setUnit(7); + tagTypeDefaultConfig.setTimeDefaultConfig(tagTimeDefaultConfig); + DefaultDisplayInfo defaultDisplayInfo = new DefaultDisplayInfo(); + defaultDisplayInfo.setDimensionIds(dataSetModelConfigs.get(0).getDimensions()); + MetricResp jsPlayCntMetric = getMetric("js_play_cnt", singerModel); + defaultDisplayInfo.setMetricIds(Lists.newArrayList(jsPlayCntMetric.getId())); + tagTypeDefaultConfig.setDefaultDisplayInfo(defaultDisplayInfo); + MetricTypeDefaultConfig metricTypeDefaultConfig = new MetricTypeDefaultConfig(); + TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig(); + timeDefaultConfig.setTimeMode(TimeMode.RECENT); + timeDefaultConfig.setUnit(7); + metricTypeDefaultConfig.setTimeDefaultConfig(timeDefaultConfig); + queryConfig.setTagTypeDefaultConfig(tagTypeDefaultConfig); + queryConfig.setMetricTypeDefaultConfig(metricTypeDefaultConfig); + dataSetReq.setQueryConfig(queryConfig); + DataSetResp dataSetResp = dataSetService.save(dataSetReq, User.getFakeUser()); + return dataSetResp.getId(); + } + + private void addAgent(long dataSetId) { + Agent agent = new Agent(); + agent.setName("做圈选"); + agent.setDescription("帮助您用自然语言进行圈选,支持多条件组合筛选"); + agent.setStatus(1); + agent.setEnableSearch(1); + agent.setExamples(Lists.newArrayList("国风风格艺人", "港台地区的艺人", "风格为流行的艺人")); + AgentConfig agentConfig = new AgentConfig(); + RuleParserTool ruleQueryTool = new RuleParserTool(); + ruleQueryTool.setId("0"); + ruleQueryTool.setType(AgentToolType.NL2SQL_RULE); + ruleQueryTool.setDataSetIds(Lists.newArrayList(dataSetId)); + agentConfig.getTools().add(ruleQueryTool); + + if (demoEnableLlm) { + LLMParserTool llmParserTool = new LLMParserTool(); + llmParserTool.setId("1"); + llmParserTool.setType(AgentToolType.NL2SQL_LLM); + llmParserTool.setDataSetIds(Lists.newArrayList(dataSetId)); + agentConfig.getTools().add(llmParserTool); + } + agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); + int id = agentService.createAgent(agent, User.getFakeUser()); + agent.setId(id); + } + +} diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2BaseDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2BaseDemo.java new file mode 100644 index 000000000..6f7581108 --- /dev/null +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2BaseDemo.java @@ -0,0 +1,176 @@ +package com.tencent.supersonic.demo; + +import com.google.common.collect.Lists; +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.auth.api.authorization.service.AuthService; +import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq; +import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq; +import com.tencent.supersonic.chat.server.service.AgentService; +import com.tencent.supersonic.chat.server.service.ChatManageService; +import com.tencent.supersonic.chat.server.service.ChatService; +import com.tencent.supersonic.chat.server.service.PluginService; +import com.tencent.supersonic.common.service.SysParameterService; +import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig; +import com.tencent.supersonic.headless.api.pojo.DrillDownDimension; +import com.tencent.supersonic.headless.api.pojo.RelateDimension; +import com.tencent.supersonic.headless.api.pojo.enums.DataType; +import com.tencent.supersonic.headless.api.pojo.enums.TagDefineType; +import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq; +import com.tencent.supersonic.headless.api.pojo.request.TagReq; +import com.tencent.supersonic.headless.api.pojo.response.*; +import com.tencent.supersonic.headless.server.pojo.MetaFilter; +import com.tencent.supersonic.headless.server.service.*; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.CommandLineRunner; +import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties; +import org.springframework.util.CollectionUtils; + +import java.util.List; +import java.util.stream.Collectors; + +@Slf4j +public abstract class S2BaseDemo implements CommandLineRunner { + protected DatabaseResp demoDatabaseResp; + + protected User user = User.getFakeUser(); + @Autowired + protected DatabaseService databaseService; + @Autowired + protected DomainService domainService; + @Autowired + protected ModelService modelService; + @Autowired + protected ModelRelaService modelRelaService; + @Autowired + protected DimensionService dimensionService; + @Autowired + protected MetricService metricService; + @Autowired + protected TagMetaService tagMetaService; + @Autowired + protected AuthService authService; + @Autowired + protected DataSetService dataSetService; + @Autowired + protected TermService termService; + @Autowired + protected PluginService pluginService; + @Autowired + protected DataSourceProperties dataSourceProperties; + @Autowired + protected TagObjectService tagObjectService; + @Autowired + protected ChatService chatService; + @Autowired + protected ChatManageService chatManageService; + @Autowired + protected AgentService agentService; + @Autowired + protected SysParameterService sysParameterService; + @Value("${s2.demo.names:S2VisitsDemo}") + protected List demoList; + @Value("${s2.demo.enableLLM:true}") + protected boolean demoEnableLlm; + + public void run(String... args) { + demoDatabaseResp = addDatabase(); + if (demoList != null && demoList.contains(getClass().getSimpleName())) { + doRun(); + } + } + + abstract void doRun(); + + protected DatabaseResp addDatabase() { + String url = dataSourceProperties.getUrl(); + DatabaseReq databaseReq = new DatabaseReq(); + databaseReq.setName("数据实例"); + databaseReq.setDescription("样例数据库实例"); + if (StringUtils.isNotBlank(url) + && url.toLowerCase().contains(DataType.MYSQL.getFeature().toLowerCase())) { + databaseReq.setType(DataType.MYSQL.getFeature()); + databaseReq.setVersion("5.7"); + } else { + databaseReq.setType(DataType.H2.getFeature()); + } + databaseReq.setUrl(url); + databaseReq.setUsername(dataSourceProperties.getUsername()); + databaseReq.setPassword(dataSourceProperties.getPassword()); + return databaseService.createOrUpdateDatabase(databaseReq, user); + } + + protected MetricResp getMetric(String bizName, ModelResp model) { + return metricService.getMetric(model.getId(), bizName); + } + + protected List getDataSetModelConfigs(Long domainId) { + List dataSetModelConfigs = Lists.newArrayList(); + List modelByDomainIds = + modelService.getModelByDomainIds(Lists.newArrayList(domainId)); + + for (ModelResp modelResp : modelByDomainIds) { + DataSetModelConfig dataSetModelConfig = new DataSetModelConfig(); + dataSetModelConfig.setId(modelResp.getId()); + MetaFilter metaFilter = new MetaFilter(); + metaFilter.setModelIds(Lists.newArrayList(modelResp.getId())); + List metrics = metricService.getMetrics(metaFilter) + .stream().map(MetricResp::getId).collect(Collectors.toList()); + dataSetModelConfig.setMetrics(metrics); + List dimensions = dimensionService.getDimensions(metaFilter) + .stream().map(DimensionResp::getId).collect(Collectors.toList()); + dataSetModelConfig.setMetrics(metrics); + dataSetModelConfig.setDimensions(dimensions); + dataSetModelConfigs.add(dataSetModelConfig); + } + return dataSetModelConfigs; + } + + + protected void parseAndExecute(int chatId, int agentId, String queryText) throws Exception { + ChatParseReq chatParseReq = new ChatParseReq(); + chatParseReq.setQueryText(queryText); + chatParseReq.setChatId(chatId); + chatParseReq.setAgentId(agentId); + chatParseReq.setUser(User.getFakeUser()); + ParseResp parseResp = chatService.performParsing(chatParseReq); + if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) { + log.info("parseResp.getSelectedParses() is empty"); + return; + } + ChatExecuteReq executeReq = new ChatExecuteReq(); + executeReq.setQueryId(parseResp.getQueryId()); + executeReq.setParseId(parseResp.getSelectedParses().get(0).getId()); + executeReq.setQueryText(queryText); + executeReq.setChatId(parseResp.getChatId()); + executeReq.setUser(User.getFakeUser()); + executeReq.setSaveAnswer(true); + chatService.performExecution(executeReq); + } + + protected void addTag(Long itemId, TagDefineType tagDefineType) { + TagReq tagReq = new TagReq(); + tagReq.setTagDefineType(tagDefineType); + tagReq.setItemId(itemId); + tagMetaService.create(tagReq, User.getFakeUser()); + } + + protected DimensionResp getDimension(String bizName, ModelResp model) { + return dimensionService.getDimension(bizName, model.getId()); + } + + protected RelateDimension getRelateDimension(List dimensionIds) { + RelateDimension relateDimension = new RelateDimension(); + for (Long id : dimensionIds) { + relateDimension.getDrillDownDimensions().add(new DrillDownDimension(id)); + } + return relateDimension; + } + + protected void updateQueryScore(Integer queryId) { + chatManageService.updateFeedback(queryId, 5, ""); + } + +} diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/ModelDemoDataLoader.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java similarity index 58% rename from launchers/standalone/src/main/java/com/tencent/supersonic/ModelDemoDataLoader.java rename to launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java index 4899d05be..83eab2573 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/ModelDemoDataLoader.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java @@ -1,196 +1,130 @@ -package com.tencent.supersonic; +package com.tencent.supersonic.demo; import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup; import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule; -import com.tencent.supersonic.auth.api.authorization.service.AuthService; +import com.tencent.supersonic.chat.server.agent.*; import com.tencent.supersonic.chat.server.plugin.Plugin; import com.tencent.supersonic.chat.server.plugin.PluginParseConfig; import com.tencent.supersonic.chat.server.plugin.build.ParamOption; import com.tencent.supersonic.chat.server.plugin.build.WebBase; -import com.tencent.supersonic.chat.server.service.PluginService; import com.tencent.supersonic.common.pojo.JoinCondition; import com.tencent.supersonic.common.pojo.ModelRela; -import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; -import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; -import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; -import com.tencent.supersonic.common.pojo.enums.SensitiveLevelEnum; -import com.tencent.supersonic.common.pojo.enums.StatusEnum; -import com.tencent.supersonic.common.pojo.enums.TimeMode; -import com.tencent.supersonic.common.pojo.enums.TypeEnums; +import com.tencent.supersonic.common.pojo.SysParameter; +import com.tencent.supersonic.common.pojo.enums.*; import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.headless.api.pojo.DataSetDetail; -import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig; -import com.tencent.supersonic.headless.api.pojo.DefaultDisplayInfo; -import com.tencent.supersonic.headless.api.pojo.Dim; -import com.tencent.supersonic.headless.api.pojo.DimensionTimeTypeParams; -import com.tencent.supersonic.headless.api.pojo.DrillDownDimension; -import com.tencent.supersonic.headless.api.pojo.Field; -import com.tencent.supersonic.headless.api.pojo.FieldParam; -import com.tencent.supersonic.headless.api.pojo.Identify; -import com.tencent.supersonic.headless.api.pojo.Measure; -import com.tencent.supersonic.headless.api.pojo.MeasureParam; -import com.tencent.supersonic.headless.api.pojo.MetricDefineByFieldParams; -import com.tencent.supersonic.headless.api.pojo.MetricDefineByMeasureParams; -import com.tencent.supersonic.headless.api.pojo.MetricDefineByMetricParams; -import com.tencent.supersonic.headless.api.pojo.MetricParam; -import com.tencent.supersonic.headless.api.pojo.MetricTypeDefaultConfig; -import com.tencent.supersonic.headless.api.pojo.ModelDetail; -import com.tencent.supersonic.headless.api.pojo.QueryConfig; -import com.tencent.supersonic.headless.api.pojo.RelateDimension; -import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig; -import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig; -import com.tencent.supersonic.headless.api.pojo.enums.DataType; -import com.tencent.supersonic.headless.api.pojo.enums.DimensionType; -import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType; -import com.tencent.supersonic.headless.api.pojo.enums.MetricDefineType; -import com.tencent.supersonic.headless.api.pojo.enums.SemanticType; -import com.tencent.supersonic.headless.api.pojo.enums.TagDefineType; -import com.tencent.supersonic.headless.api.pojo.request.DataSetReq; -import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq; -import com.tencent.supersonic.headless.api.pojo.request.DimensionReq; -import com.tencent.supersonic.headless.api.pojo.request.DomainReq; -import com.tencent.supersonic.headless.api.pojo.request.MetricReq; -import com.tencent.supersonic.headless.api.pojo.request.ModelReq; -import com.tencent.supersonic.headless.api.pojo.request.TagObjectReq; -import com.tencent.supersonic.headless.api.pojo.request.TagReq; -import com.tencent.supersonic.headless.api.pojo.request.TermReq; -import com.tencent.supersonic.headless.api.pojo.response.DataSetResp; -import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp; -import com.tencent.supersonic.headless.api.pojo.response.DimensionResp; -import com.tencent.supersonic.headless.api.pojo.response.DomainResp; -import com.tencent.supersonic.headless.api.pojo.response.MetricResp; -import com.tencent.supersonic.headless.api.pojo.response.ModelResp; -import com.tencent.supersonic.headless.api.pojo.response.TagObjectResp; -import com.tencent.supersonic.headless.server.pojo.MetaFilter; -import com.tencent.supersonic.headless.server.service.DataSetService; -import com.tencent.supersonic.headless.server.service.DatabaseService; -import com.tencent.supersonic.headless.server.service.DimensionService; -import com.tencent.supersonic.headless.server.service.DomainService; -import com.tencent.supersonic.headless.server.service.MetricService; -import com.tencent.supersonic.headless.server.service.ModelRelaService; -import com.tencent.supersonic.headless.server.service.ModelService; -import com.tencent.supersonic.headless.server.service.TagMetaService; -import com.tencent.supersonic.headless.server.service.TagObjectService; -import com.tencent.supersonic.headless.server.service.TermService; +import com.tencent.supersonic.headless.api.pojo.*; +import com.tencent.supersonic.headless.api.pojo.enums.*; +import com.tencent.supersonic.headless.api.pojo.request.*; +import com.tencent.supersonic.headless.api.pojo.response.*; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.StringUtils; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.boot.autoconfigure.jdbc.DataSourceProperties; +import org.springframework.core.annotation.Order; import org.springframework.stereotype.Component; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.stream.Collectors; @Component @Slf4j -public class ModelDemoDataLoader { - - protected DatabaseResp tmpDatabaseResp = null; - private User user = User.getFakeUser(); - @Autowired - private DatabaseService databaseService; - @Autowired - private DomainService domainService; - @Autowired - private ModelService modelService; - @Autowired - private ModelRelaService modelRelaService; - @Autowired - private DimensionService dimensionService; - @Autowired - private MetricService metricService; - @Autowired - private AuthService authService; - @Autowired - private DataSetService dataSetService; - @Autowired - private DataSourceProperties dataSourceProperties; - @Autowired - private TagObjectService tagObjectService; - @Autowired - private TagMetaService tagMetaService; - @Autowired - private TermService termService; - @Autowired - private PluginService pluginService; +@Order(1) +public class S2VisitsDemo extends S2BaseDemo { public void doRun() { try { - DatabaseResp databaseResp = addDatabase(); - tmpDatabaseResp = databaseResp; + // create domain DomainResp s2Domain = addDomain(); TagObjectResp s2TagObject = addTagObjectUser(s2Domain); - ModelResp userModel = addModel_1(s2Domain, databaseResp, s2TagObject); - ModelResp pvUvModel = addModel_2(s2Domain, databaseResp); - DimensionResp userDimension = getDimension("user_name", userModel); + + // create models + ModelResp userModel = addModel_1(s2Domain, demoDatabaseResp, s2TagObject); + ModelResp pvUvModel = addModel_2(s2Domain, demoDatabaseResp); + ModelResp stayTimeModel = addModel_3(s2Domain, demoDatabaseResp); + addModelRela_1(s2Domain, userModel, pvUvModel); + addModelRela_2(s2Domain, userModel, stayTimeModel); + addTags(userModel); + + //create metrics and dimensions DimensionResp departmentDimension = getDimension("department", userModel); MetricResp metricUv = addMetric_uv(userModel, departmentDimension); MetricResp metricPv = getMetric("pv", pvUvModel); addMetric_pv_avg(metricPv, metricUv, departmentDimension, pvUvModel); - ModelResp stayTimeModel = addModel_3(s2Domain, databaseResp); - addModelRela_1(s2Domain, userModel, pvUvModel); - addModelRela_2(s2Domain, userModel, stayTimeModel); - DomainResp singerDomain = addDomain_2(); - TagObjectResp singerTagObject = addTagObjectSinger(singerDomain); - ModelResp singerModel = addModel_4(singerDomain, databaseResp, singerTagObject); + DimensionResp pageDimension = getDimension("page", stayTimeModel); updateDimension(stayTimeModel, pageDimension); + DimensionResp userDimension = getDimension("user_name", userModel); updateMetric(stayTimeModel, departmentDimension, userDimension); - addTags(userModel, singerModel); updateMetric_pv(pvUvModel, departmentDimension, userDimension, metricPv); - DataSetResp s2DataSet = addDataSet_1(s2Domain); - addDataSet_2(singerDomain, singerModel); + + //create data set + DataSetResp s2DataSet = addDataSet(s2Domain); addAuthGroup_1(stayTimeModel); addAuthGroup_2(stayTimeModel); + + //create terms and plugin addTerm(s2Domain); addTerm_1(s2Domain); - addPlugin_1(s2DataSet, userDimension, userModel); + addPlugin(s2DataSet, userDimension, userModel); + addSysParameter(); + + //create agent + Integer agentId = addAgent(s2DataSet.getId()); + addSampleChats(agentId); + updateQueryScore(1); + updateQueryScore(4); } catch (Exception e) { - log.error("Failed to add model demo data", e); + log.error("Failed to add S2Visits demo data", e); } - } - private TagObjectResp addTagObjectUser(DomainResp s2Domain) throws Exception { - TagObjectReq tagObjectReq = new TagObjectReq(); - tagObjectReq.setDomainId(s2Domain.getId()); - tagObjectReq.setName("用户"); - tagObjectReq.setBizName("user"); - User user = User.getFakeUser(); - return tagObjectService.create(tagObjectReq, user); + public void addSampleChats(Integer agentId) throws Exception { + Long chatId = chatManageService.addChat(user, "样例对话1", agentId); + + parseAndExecute(chatId.intValue(), agentId, "超音数 访问次数"); + parseAndExecute(chatId.intValue(), agentId, "按部门统计"); + parseAndExecute(chatId.intValue(), agentId, "查询近30天"); + parseAndExecute(chatId.intValue(), agentId, "alice 停留时长"); + parseAndExecute(chatId.intValue(), agentId, "对比alice和lucy的访问次数"); + parseAndExecute(chatId.intValue(), agentId, "访问次数最高的部门"); } - private TagObjectResp addTagObjectSinger(DomainResp singerDomain) throws Exception { - TagObjectReq tagObjectReq = new TagObjectReq(); - tagObjectReq.setDomainId(singerDomain.getId()); - tagObjectReq.setName("艺人"); - tagObjectReq.setBizName("singer"); - User user = User.getFakeUser(); - return tagObjectService.create(tagObjectReq, user); + public void addSysParameter() { + SysParameter sysParameter = new SysParameter(); + sysParameter.setId(1); + sysParameter.init(); + sysParameterService.save(sysParameter); } - public DatabaseResp addDatabase() { - String url = dataSourceProperties.getUrl(); - DatabaseReq databaseReq = new DatabaseReq(); - databaseReq.setName("数据实例"); - databaseReq.setDescription("样例数据库实例"); - if (StringUtils.isNotBlank(url) - && url.toLowerCase().contains(DataType.MYSQL.getFeature().toLowerCase())) { - databaseReq.setType(DataType.MYSQL.getFeature()); - databaseReq.setVersion("5.7"); - } else { - databaseReq.setType(DataType.H2.getFeature()); + private Integer addAgent(long dataSetId) { + Agent agent = new Agent(); + agent.setName("算指标"); + agent.setDescription("帮助您用自然语言查询指标,支持时间限定、条件筛选、下钻维度以及聚合统计"); + agent.setStatus(1); + agent.setEnableSearch(1); + agent.setExamples(Lists.newArrayList("超音数访问次数", "近15天超音数访问次数汇总", "按部门统计超音数的访问人数", + "对比alice和lucy的停留时长", "超音数访问次数最高的部门")); + AgentConfig agentConfig = new AgentConfig(); + RuleParserTool ruleQueryTool = new RuleParserTool(); + ruleQueryTool.setType(AgentToolType.NL2SQL_RULE); + ruleQueryTool.setId("0"); + ruleQueryTool.setDataSetIds(Lists.newArrayList(dataSetId)); + agentConfig.getTools().add(ruleQueryTool); + if (demoEnableLlm) { + LLMParserTool llmParserTool = new LLMParserTool(); + llmParserTool.setId("1"); + llmParserTool.setType(AgentToolType.NL2SQL_LLM); + llmParserTool.setDataSetIds(Lists.newArrayList(dataSetId)); + agentConfig.getTools().add(llmParserTool); } - databaseReq.setUrl(url); - databaseReq.setUsername(dataSourceProperties.getUsername()); - databaseReq.setPassword(dataSourceProperties.getPassword()); - return databaseService.createOrUpdateDatabase(databaseReq, user); + agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); + MultiTurnConfig multiTurnConfig = new MultiTurnConfig(false); + agent.setMultiTurnConfig(multiTurnConfig); + int id = agentService.createAgent(agent, User.getFakeUser()); + agent.setId(id); + return agent.getId(); } public DomainResp addDomain() { @@ -348,82 +282,9 @@ public class ModelDemoDataLoader { modelRelaService.save(modelRelaReq, user); } - public DomainResp addDomain_2() { - DomainReq domainReq = new DomainReq(); - domainReq.setName("艺人库"); - domainReq.setBizName("supersonic"); - domainReq.setParentId(0L); - domainReq.setStatus(StatusEnum.ONLINE.getCode()); - domainReq.setViewers(Arrays.asList("admin", "tom", "jack")); - domainReq.setViewOrgs(Collections.singletonList("1")); - domainReq.setAdmins(Arrays.asList("admin", "alice")); - domainReq.setAdminOrgs(Collections.emptyList()); - return domainService.createDomain(domainReq, user); - } - - public ModelResp addModel_4(DomainResp singerDomain, - DatabaseResp s2Database, TagObjectResp singerTagObject) throws Exception { - ModelReq modelReq = new ModelReq(); - modelReq.setName("艺人库"); - modelReq.setBizName("singer"); - modelReq.setDescription("艺人库"); - modelReq.setDatabaseId(s2Database.getId()); - modelReq.setDomainId(singerDomain.getId()); - modelReq.setTagObjectId(singerTagObject.getId()); - modelReq.setViewers(Arrays.asList("admin", "tom", "jack")); - modelReq.setViewOrgs(Collections.singletonList("1")); - modelReq.setAdmins(Collections.singletonList("admin")); - modelReq.setAdminOrgs(Collections.emptyList()); - ModelDetail modelDetail = new ModelDetail(); - List identifiers = new ArrayList<>(); - Identify identify = new Identify("歌手名", IdentifyType.primary.name(), "singer_name", 1); - identify.setEntityNames(Lists.newArrayList("歌手", "艺人")); - identifiers.add(identify); - modelDetail.setIdentifiers(identifiers); - - List dimensions = new ArrayList<>(); - Dim dimension1 = new Dim("", "imp_date", DimensionType.time.name(), 0); - dimension1.setTypeParams(new DimensionTimeTypeParams()); - dimensions.add(dimension1); - dimensions.add(new Dim("活跃区域", "act_area", - DimensionType.categorical.name(), 1, 1)); - dimensions.add(new Dim("代表作", "song_name", - DimensionType.categorical.name(), 1)); - dimensions.add(new Dim("风格", "genre", - DimensionType.categorical.name(), 1, 1)); - modelDetail.setDimensions(dimensions); - - Measure measure1 = new Measure("播放量", "js_play_cnt", "sum", 1); - Measure measure2 = new Measure("下载量", "down_cnt", "sum", 1); - Measure measure3 = new Measure("收藏量", "favor_cnt", "sum", 1); - modelDetail.setMeasures(Lists.newArrayList(measure1, measure2, measure3)); - modelDetail.setQueryType("sql_query"); - modelDetail.setSqlQuery("select imp_date, singer_name, act_area, song_name, genre, " - + "js_play_cnt, down_cnt, favor_cnt from singer"); - modelReq.setModelDetail(modelDetail); - return modelService.createModel(modelReq, user); - } - - private void addTags(ModelResp userModel, ModelResp singerModel) { - addTag(dimensionService.getDimension("department", userModel.getId()).getId(), + private void addTags(ModelResp model) { + addTag(dimensionService.getDimension("department", model.getId()).getId(), TagDefineType.DIMENSION); - addTag(dimensionService.getDimension("act_area", singerModel.getId()).getId(), - TagDefineType.DIMENSION); - addTag(dimensionService.getDimension("song_name", singerModel.getId()).getId(), - TagDefineType.DIMENSION); - addTag(dimensionService.getDimension("genre", singerModel.getId()).getId(), - TagDefineType.DIMENSION); - addTag(dimensionService.getDimension("singer_name", singerModel.getId()).getId(), - TagDefineType.DIMENSION); - addTag(metricService.getMetric(singerModel.getId(), "js_play_cnt").getId(), - TagDefineType.METRIC); - } - - private void addTag(Long itemId, TagDefineType tagDefineType) { - TagReq tagReq = new TagReq(); - tagReq.setTagDefineType(tagDefineType); - tagReq.setItemId(itemId); - tagMetaService.create(tagReq, User.getFakeUser()); } public void updateDimension(ModelResp stayTimeModel, DimensionResp pageDimension) throws Exception { @@ -535,7 +396,7 @@ public class ModelDemoDataLoader { return metricService.createMetric(metricReq, user); } - public DataSetResp addDataSet_1(DomainResp s2Domain) { + public DataSetResp addDataSet(DomainResp s2Domain) { DataSetReq dataSetReq = new DataSetReq(); dataSetReq.setName("超音数"); dataSetReq.setBizName("s2"); @@ -558,39 +419,6 @@ public class ModelDemoDataLoader { return dataSetService.save(dataSetReq, User.getFakeUser()); } - public void addDataSet_2(DomainResp singerDomain, ModelResp singerModel) { - DataSetReq dataSetReq = new DataSetReq(); - dataSetReq.setName("艺人库"); - dataSetReq.setBizName("singer"); - dataSetReq.setDomainId(singerDomain.getId()); - dataSetReq.setDescription("包含艺人相关标签和指标信息"); - dataSetReq.setAdmins(Lists.newArrayList("admin", "jack")); - List dataSetModelConfigs = getDataSetModelConfigs(singerDomain.getId()); - DataSetDetail dataSetDetail = new DataSetDetail(); - dataSetDetail.setDataSetModelConfigs(dataSetModelConfigs); - dataSetReq.setDataSetDetail(dataSetDetail); - dataSetReq.setTypeEnum(TypeEnums.DATASET); - QueryConfig queryConfig = new QueryConfig(); - TagTypeDefaultConfig tagTypeDefaultConfig = new TagTypeDefaultConfig(); - TimeDefaultConfig tagTimeDefaultConfig = new TimeDefaultConfig(); - tagTimeDefaultConfig.setTimeMode(TimeMode.LAST); - tagTimeDefaultConfig.setUnit(7); - tagTypeDefaultConfig.setTimeDefaultConfig(tagTimeDefaultConfig); - DefaultDisplayInfo defaultDisplayInfo = new DefaultDisplayInfo(); - defaultDisplayInfo.setDimensionIds(dataSetModelConfigs.get(0).getDimensions()); - MetricResp jsPlayCntMetric = getMetric("js_play_cnt", singerModel); - defaultDisplayInfo.setMetricIds(Lists.newArrayList(jsPlayCntMetric.getId())); - tagTypeDefaultConfig.setDefaultDisplayInfo(defaultDisplayInfo); - MetricTypeDefaultConfig metricTypeDefaultConfig = new MetricTypeDefaultConfig(); - TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig(); - timeDefaultConfig.setTimeMode(TimeMode.RECENT); - timeDefaultConfig.setUnit(7); - metricTypeDefaultConfig.setTimeDefaultConfig(timeDefaultConfig); - queryConfig.setTagTypeDefaultConfig(tagTypeDefaultConfig); - queryConfig.setMetricTypeDefaultConfig(metricTypeDefaultConfig); - dataSetReq.setQueryConfig(queryConfig); - dataSetService.save(dataSetReq, User.getFakeUser()); - } public void addTerm(DomainResp s2Domain) { TermReq termReq = new TermReq(); @@ -641,8 +469,8 @@ public class ModelDemoDataLoader { authService.addOrUpdateAuthGroup(authGroupReq); } - private void addPlugin_1(DataSetResp s2DataSet, DimensionResp userDimension, - ModelResp userModel) { + private void addPlugin(DataSetResp s2DataSet, DimensionResp userDimension, + ModelResp userModel) { Plugin plugin1 = new Plugin(); plugin1.setType("WEB_PAGE"); plugin1.setDataSetList(Arrays.asList(s2DataSet.getId())); @@ -666,42 +494,13 @@ public class ModelDemoDataLoader { pluginService.createPlugin(plugin1, user); } - private RelateDimension getRelateDimension(List dimensionIds) { - RelateDimension relateDimension = new RelateDimension(); - for (Long id : dimensionIds) { - relateDimension.getDrillDownDimensions().add(new DrillDownDimension(id)); - } - return relateDimension; - } - - private DimensionResp getDimension(String bizName, ModelResp model) { - return dimensionService.getDimension(bizName, model.getId()); - } - - private MetricResp getMetric(String bizName, ModelResp model) { - return metricService.getMetric(model.getId(), bizName); - } - - protected List getDataSetModelConfigs(Long domainId) { - List dataSetModelConfigs = Lists.newArrayList(); - List modelByDomainIds = - modelService.getModelByDomainIds(Lists.newArrayList(domainId)); - - for (ModelResp modelResp : modelByDomainIds) { - DataSetModelConfig dataSetModelConfig = new DataSetModelConfig(); - dataSetModelConfig.setId(modelResp.getId()); - MetaFilter metaFilter = new MetaFilter(); - metaFilter.setModelIds(Lists.newArrayList(modelResp.getId())); - List metrics = metricService.getMetrics(metaFilter) - .stream().map(MetricResp::getId).collect(Collectors.toList()); - dataSetModelConfig.setMetrics(metrics); - List dimensions = dimensionService.getDimensions(metaFilter) - .stream().map(DimensionResp::getId).collect(Collectors.toList()); - dataSetModelConfig.setMetrics(metrics); - dataSetModelConfig.setDimensions(dimensions); - dataSetModelConfigs.add(dataSetModelConfig); - } - return dataSetModelConfigs; + private TagObjectResp addTagObjectUser(DomainResp s2Domain) throws Exception { + TagObjectReq tagObjectReq = new TagObjectReq(); + tagObjectReq.setDomainId(s2Domain.getId()); + tagObjectReq.setName("用户"); + tagObjectReq.setBizName("user"); + User user = User.getFakeUser(); + return tagObjectService.create(tagObjectReq, user); } } diff --git a/launchers/standalone/src/main/resources/application-local.yaml b/launchers/standalone/src/main/resources/application-local.yaml index 098dbfe37..e008861f0 100644 --- a/launchers/standalone/src/main/resources/application-local.yaml +++ b/launchers/standalone/src/main/resources/application-local.yaml @@ -9,7 +9,6 @@ spring: h2: console: path: /h2-console/semantic - # enabled web enabled: true datasource: driver-class-name: org.h2.Driver @@ -31,16 +30,10 @@ authentication: header: key: Authorization -demo: - enabled: true - query: optimizer: enable: true -multi: - turn: false - time: threshold: 100 @@ -70,6 +63,11 @@ text2sql: num: 1 s2: + demo: + names: S2VisitsDemo,S2ArtistDemo + enableLLM: true + multi-turn: + enable: false langchain4j: #1.chat-model chat-model: diff --git a/launchers/standalone/src/test/resources/application-local.yaml b/launchers/standalone/src/test/resources/application-local.yaml index 10bf0b86b..e008861f0 100644 --- a/launchers/standalone/src/test/resources/application-local.yaml +++ b/launchers/standalone/src/test/resources/application-local.yaml @@ -1,8 +1,14 @@ +server: + port: 9080 + compression: + enabled: true + min-response-size: 1024 + mime-types: application/javascript,application/json,application/xml,text/html,text/xml,text/plain,text/css,image/* + spring: h2: console: - path: /h2-console/chat - # enabled web + path: /h2-console/semantic enabled: true datasource: driver-class-name: org.h2.Driver @@ -12,32 +18,56 @@ spring: username: root password: semantic -demo: - enabled: true - nl2SqlLlm: - enabled: false - -server: - port: 9080 +mybatis: + mapper-locations=classpath:mappers/custom/*.xml,classpath*:/mappers/*.xml authentication: - enable: false + enable: true exclude: path: /api/auth/user/register,/api/auth/user/login + token: + http: + header: + key: Authorization -semantic: - url: - prefix: http://127.0.0.1:9081 +query: + optimizer: + enable: true time: threshold: 100 -mybatis: - mapper-locations=classpath:mappers/custom/*.xml,classpath*:/mappers/*.xml +dimension: + topn: 20 +metric: + topn: 20 +corrector: + additional: + information: true + +pyllm: + url: http://127.0.0.1:9092 +llm: + parser: + url: ${pyllm.url} + +embedding: + url: ${pyllm.url} + +functionCall: + url: ${pyllm.url} + +text2sql: + example: + num: 1 -#langchain4j config s2: + demo: + names: S2VisitsDemo,S2ArtistDemo + enableLLM: true + multi-turn: + enable: false langchain4j: #1.chat-model chat-model: @@ -58,6 +88,7 @@ s2: # inProcess: # modelPath: /data/model.onnx # vocabularyPath: /data/onnx_vocab.txt + # shibing624/text2vec-base-chinese #2.2 open_ai # embedding-model: # provider: open_ai