From cf1b5336c3d9d5ac4f70506c5b47c0fb526fe68c Mon Sep 17 00:00:00 2001 From: mainmain <57514971+mainmainer@users.noreply.github.com> Date: Sun, 20 Aug 2023 17:30:35 +0800 Subject: [PATCH] [improvement](supersonic) based on version 0.7.2 (#34) Co-authored-by: zuopengge --- .../request/QueryAuthResReq.java | 2 - .../authorization/service/AuthService.java | 3 +- .../auth/authentication/config/TppConfig.java | 21 - .../application/AuthServiceImpl.java | 29 +- .../authorization/rest/AuthController.java | 12 +- .../chat/api/component/DSLOptimizer.java | 9 + .../chat/api/pojo/CorrectionInfo.java | 21 + .../chat/api/pojo/request/QueryReq.java | 1 + chat/core/pom.xml | 6 - .../tencent/supersonic/chat/agent/Agent.java | 43 + .../supersonic/chat/agent/AgentConfig.java | 18 + .../supersonic/chat/agent/tool/AgentTool.java | 19 + .../chat/agent/tool/AgentToolType.java | 8 + .../supersonic/chat/agent/tool/DslTool.java | 14 + .../chat/agent/tool/MetricInterpretTool.java | 16 + .../chat/agent/tool/PluginTool.java | 13 + .../chat/agent/tool/RuleQueryTool.java | 13 + .../supersonic/chat/mapper/MapperHelper.java | 4 +- .../chat/mapper/QueryMatchStrategy.java | 39 +- .../chat/parser/SatisfactionChecker.java | 7 +- .../embedding/EmbeddingBasedParser.java | 54 +- .../parser/function/FunctionBasedParser.java | 58 +- .../function/HeuristicModelResolver.java | 39 +- .../chat/parser/function/ModelResolver.java | 3 +- .../chat/parser/llm/DSLParseResult.java | 11 - .../chat/parser/llm/dsl/DSLDateHelper.java | 20 + .../chat/parser/llm/dsl/DSLParseResult.java | 17 + .../parser/llm/{ => dsl}/LLMDSLParser.java | 80 +- .../llm/interpret/MetricInterpretParser.java | 144 +++ .../parser/llm/interpret/MetricOption.java | 14 + .../{ => time}/LLMTimeEnhancementParse.java | 19 +- .../chat/parser/rule/AgentCheckParser.java | 63 + .../chat/parser/rule/QueryModeParser.java | 9 +- .../chat/persistence/dataobject/AgentDO.java | 236 ++++ .../dataobject/AgentDOExample.java | 1025 +++++++++++++++++ .../persistence/mapper/AgentDOMapper.java | 71 ++ .../repository/AgentRepository.java | 18 + .../repository/impl/AgentRepositoryImpl.java | 43 + .../impl/ChatContextRepositoryImpl.java | 6 +- .../supersonic/chat/plugin/PluginManager.java | 76 +- .../ContentInterpretQuery.java | 149 --- .../chat/query/HeuristicQuerySelector.java | 3 +- .../supersonic/chat/query/QuerySelector.java | 4 +- .../supersonic/chat/query/dsl/DSLBuilder.java | 78 -- .../supersonic/chat/query/dsl/DSLQuery.java | 42 +- .../query/dsl/optimizer/BaseDSLOptimizer.java | 33 + .../dsl/optimizer/DateFieldCorrector.java | 26 + .../query/dsl/optimizer/FieldCorrector.java | 17 + .../dsl/optimizer/FunctionCorrector.java | 16 + .../dsl/optimizer/QueryFilterAppend.java | 48 + .../optimizer/SelectFieldAppendCorrector.java | 35 + .../dsl/optimizer/TableNameCorrector.java | 21 + .../LLmAnswerReq.java | 2 +- .../LLmAnswerResp.java | 2 +- .../metricInterpret/MetricInterpretQuery.java | 143 +++ .../query/plugin/webpage/WebPageQuery.java | 21 +- .../chat/query/rule/RuleSemanticQuery.java | 3 +- .../query/rule/entity/EntityFilterQuery.java | 9 +- .../chat/query/rule/entity/EntityIdQuery.java | 23 + .../supersonic/chat/rest/AgentController.java | 51 + .../chat/rest/ChatConfigController.java | 2 +- .../supersonic/chat/rest/ChatController.java | 2 +- .../chat/rest/ChatQueryController.java | 2 +- .../chat/rest/RecommendController.java | 2 +- .../supersonic/chat/service/AgentService.java | 19 + .../chat/service/impl/AgentServiceImpl.java | 82 ++ .../chat/service/impl/QueryServiceImpl.java | 19 +- .../service/impl/RecommendServiceImpl.java | 2 + .../chat/service/impl/SearchServiceImpl.java | 26 +- .../supersonic/chat/utils/ChatGptHelper.java | 77 -- .../chat/utils/ComponentFactory.java | 15 +- .../supersonic/chat/utils/NatureHelper.java | 4 +- chat/core/src/main/python/bin/env.sh | 1 - .../main/resources/mapper/AgentDOMapper.xml | 303 +++++ chat/core/src/main/resources/sql.ddl/chat.sql | 16 +- .../dsl/optimizer/DateFieldCorrectorTest.java | 37 + .../SelectFieldAppendCorrectorTest.java | 25 + .../builder/DimensionWordBuilder.java | 19 +- .../dictionary/builder/MetricWordBuilder.java | 19 +- .../semantic/LocalSemanticLayer.java | 27 +- .../semantic/ModelSchemaBuilder.java | 45 +- common/pom.xml | 6 + .../supersonic/common/util/ChatGptHelper.java | 130 +++ .../supersonic/common/util/DateUtils.java | 37 +- .../supersonic/common/util/StringUtil.java | 1 - .../util/jsqlparser/CCJSqlParserUtils.java | 26 +- .../util/jsqlparser/FieldReplaceVisitor.java | 48 - .../jsqlparser/FunctionReplaceVisitor.java | 171 +++ .../jsqlparser/GroupByReplaceVisitor.java | 17 +- .../util/jsqlparser/ParseVisitorHelper.java | 99 -- .../supersonic/common/util/DateUtilsTest.java | 12 + .../jsqlparser/CCJSqlParserUtilsTest.java | 77 +- .../main/resources/META-INF/spring.factories | 14 +- .../com/tencent/supersonic/ConfigureDemo.java | 72 +- .../main/resources/META-INF/spring.factories | 11 +- .../src/main/resources/db/schema-h2.sql | 16 + .../supersonic/integration/BaseQueryTest.java | 9 + .../integration/MetricInterpretTest.java | 56 + .../integration/MetricQueryTest.java | 46 +- ...figuration.java => MockConfiguration.java} | 20 +- .../plugin/PluginRecognizeTest.java | 39 +- .../tencent/supersonic/util/DataUtils.java | 70 +- .../test/resources/META-INF/spring.factories | 4 +- .../model/application/CatalogImpl.java | 30 +- .../application/DimensionServiceImpl.java | 82 +- .../model/application/DomainServiceImpl.java | 15 +- .../model/application/MetricServiceImpl.java | 27 +- .../model/domain/DimensionService.java | 6 + .../semantic/model/domain/MetricService.java | 2 + .../domain/utils/DimensionConverter.java | 2 + .../model/rest/DimensionController.java | 25 +- .../semantic/model/rest/MetricController.java | 18 +- .../parser/convert/CalculateAggConverter.java | 2 +- .../convert/ParserDefaultConverter.java | 21 +- .../semantic/query/rest/QueryController.java | 2 +- .../semantic/query/service/QueryService.java | 7 +- .../query/service/QueryServiceImpl.java | 12 +- .../query/utils/DataPermissionAOP.java | 45 +- .../semantic/query/utils/DimValueAspect.java | 3 +- .../query/utils/QueryReqConverter.java | 10 + .../query/utils/QueryStructUtils.java | 32 +- .../semantic/query/utils/QueryUtils.java | 24 +- 122 files changed, 4045 insertions(+), 1075 deletions(-) delete mode 100644 auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/config/TppConfig.java create mode 100644 chat/api/src/main/java/com/tencent/supersonic/chat/api/component/DSLOptimizer.java create mode 100644 chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/CorrectionInfo.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/agent/Agent.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/agent/AgentConfig.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/AgentTool.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/AgentToolType.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/DslTool.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/MetricInterpretTool.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/PluginTool.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/RuleQueryTool.java delete mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/DSLParseResult.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/DSLDateHelper.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/DSLParseResult.java rename chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/{ => dsl}/LLMDSLParser.java (82%) create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/interpret/MetricInterpretParser.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/interpret/MetricOption.java rename chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/{ => time}/LLMTimeEnhancementParse.java (69%) create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/AgentCheckParser.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/persistence/dataobject/AgentDO.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/persistence/dataobject/AgentDOExample.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/persistence/mapper/AgentDOMapper.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/AgentRepository.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/impl/AgentRepositoryImpl.java delete mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/query/ContentInterpret/ContentInterpretQuery.java delete mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/DSLBuilder.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/BaseDSLOptimizer.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/DateFieldCorrector.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/FieldCorrector.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/FunctionCorrector.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/QueryFilterAppend.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/SelectFieldAppendCorrector.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/TableNameCorrector.java rename chat/core/src/main/java/com/tencent/supersonic/chat/query/{ContentInterpret => metricInterpret}/LLmAnswerReq.java (67%) rename chat/core/src/main/java/com/tencent/supersonic/chat/query/{ContentInterpret => metricInterpret}/LLmAnswerResp.java (62%) create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/query/metricInterpret/MetricInterpretQuery.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/entity/EntityIdQuery.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/rest/AgentController.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/service/AgentService.java create mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/AgentServiceImpl.java delete mode 100644 chat/core/src/main/java/com/tencent/supersonic/chat/utils/ChatGptHelper.java create mode 100644 chat/core/src/main/resources/mapper/AgentDOMapper.xml create mode 100644 chat/core/src/test/java/com/tencent/supersonic/chat/query/dsl/optimizer/DateFieldCorrectorTest.java create mode 100644 chat/core/src/test/java/com/tencent/supersonic/chat/query/dsl/optimizer/SelectFieldAppendCorrectorTest.java create mode 100644 common/src/main/java/com/tencent/supersonic/common/util/ChatGptHelper.java create mode 100644 common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FunctionReplaceVisitor.java create mode 100644 launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java rename launchers/standalone/src/test/java/com/tencent/supersonic/integration/{plugin/PluginMockConfiguration.java => MockConfiguration.java} (79%) diff --git a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authorization/request/QueryAuthResReq.java b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authorization/request/QueryAuthResReq.java index a006769e9..9eed07aef 100644 --- a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authorization/request/QueryAuthResReq.java +++ b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authorization/request/QueryAuthResReq.java @@ -10,8 +10,6 @@ import lombok.ToString; @ToString public class QueryAuthResReq { - private String user; - private List departmentIds = new ArrayList<>(); private List resources; diff --git a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authorization/service/AuthService.java b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authorization/service/AuthService.java index 42b478f0e..60415954f 100644 --- a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authorization/service/AuthService.java +++ b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authorization/service/AuthService.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.auth.api.authorization.service; +import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup; import com.tencent.supersonic.auth.api.authorization.request.QueryAuthResReq; import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp; @@ -14,5 +15,5 @@ public interface AuthService { void removeAuthGroup(AuthGroup group); - AuthorizedResourceResp queryAuthorizedResources(QueryAuthResReq req, HttpServletRequest request); + AuthorizedResourceResp queryAuthorizedResources(QueryAuthResReq req, User user); } diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/config/TppConfig.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/config/TppConfig.java deleted file mode 100644 index 474de6004..000000000 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/config/TppConfig.java +++ /dev/null @@ -1,21 +0,0 @@ -package com.tencent.supersonic.auth.authentication.config; - - -import lombok.Data; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.context.annotation.Configuration; - -@Data -@Configuration -public class TppConfig { - - @Value(value = "${auth.app.secret:}") - private String appSecret; - - @Value(value = "${auth.app.key:}") - private String appKey; - - @Value(value = "${auth.oa.url:}") - private String tppOaUrl; - -} diff --git a/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/application/AuthServiceImpl.java b/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/application/AuthServiceImpl.java index 5645a8077..2f5cda273 100644 --- a/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/application/AuthServiceImpl.java +++ b/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/application/AuthServiceImpl.java @@ -2,27 +2,24 @@ package com.tencent.supersonic.auth.authorization.application; import com.google.common.base.Strings; import com.google.gson.Gson; +import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.service.UserService; -import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup; import com.tencent.supersonic.auth.api.authorization.pojo.AuthRes; import com.tencent.supersonic.auth.api.authorization.pojo.AuthResGrp; -import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule; import com.tencent.supersonic.auth.api.authorization.pojo.DimensionFilter; import com.tencent.supersonic.auth.api.authorization.request.QueryAuthResReq; import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp; import com.tencent.supersonic.auth.api.authorization.service.AuthService; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.stream.Collectors; -import javax.servlet.http.HttpServletRequest; +import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup; +import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule; +import com.tencent.supersonic.common.util.S2ThreadContext; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; +import java.util.*; +import java.util.stream.Collectors; @Service @Slf4j @@ -33,7 +30,7 @@ public class AuthServiceImpl implements AuthService { private UserService userService; public AuthServiceImpl(JdbcTemplate jdbcTemplate, - UserService userService) { + UserService userService) { this.jdbcTemplate = jdbcTemplate; this.userService = userService; } @@ -78,12 +75,12 @@ public class AuthServiceImpl implements AuthService { @Override - public AuthorizedResourceResp queryAuthorizedResources(QueryAuthResReq req, HttpServletRequest request) { - Set userOrgIds = userService.getUserAllOrgId(req.getUser()); + public AuthorizedResourceResp queryAuthorizedResources(QueryAuthResReq req, User user) { + Set userOrgIds = userService.getUserAllOrgId(user.getName()); if (!CollectionUtils.isEmpty(userOrgIds)) { req.setDepartmentIds(new ArrayList<>(userOrgIds)); } - List groups = getAuthGroups(req); + List groups = getAuthGroups(req, user.getName()); AuthorizedResourceResp resource = new AuthorizedResourceResp(); Map> authGroupsByModelId = groups.stream() .collect(Collectors.groupingBy(AuthGroup::getModelId)); @@ -130,14 +127,14 @@ public class AuthServiceImpl implements AuthService { return resource; } - private List getAuthGroups(QueryAuthResReq req) { + private List getAuthGroups(QueryAuthResReq req, String userName) { List groups = load().stream() .filter(group -> { if (!Objects.equals(group.getModelId(), req.getModelId())) { return false; } if (!CollectionUtils.isEmpty(group.getAuthorizedUsers()) && group.getAuthorizedUsers() - .contains(req.getUser())) { + .contains(userName)) { return true; } for (String departmentId : req.getDepartmentIds()) { @@ -148,7 +145,7 @@ public class AuthServiceImpl implements AuthService { } return false; }).collect(Collectors.toList()); - log.info("user:{} department:{} authGroups:{}", req.getUser(), req.getDepartmentIds(), groups); + log.info("user:{} department:{} authGroups:{}", userName, req.getDepartmentIds(), groups); return groups; } diff --git a/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/rest/AuthController.java b/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/rest/AuthController.java index bede33b4c..7629ca64c 100644 --- a/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/rest/AuthController.java +++ b/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/rest/AuthController.java @@ -1,11 +1,14 @@ package com.tencent.supersonic.auth.authorization.rest; -import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup; +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; import com.tencent.supersonic.auth.api.authorization.request.QueryAuthResReq; import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp; import com.tencent.supersonic.auth.api.authorization.service.AuthService; +import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup; import java.util.List; import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; import lombok.extern.slf4j.Slf4j; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PostMapping; @@ -62,12 +65,13 @@ public class AuthController { * 查询有权限访问的受限资源id * * @param req - * @param request * @return */ @PostMapping("/queryAuthorizedRes") public AuthorizedResourceResp queryAuthorizedResources(@RequestBody QueryAuthResReq req, - HttpServletRequest request) { - return authService.queryAuthorizedResources(req, request); + HttpServletRequest request, + HttpServletResponse response) { + User user = UserHolder.findUser(request, response); + return authService.queryAuthorizedResources(req, user); } } diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/DSLOptimizer.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/DSLOptimizer.java new file mode 100644 index 000000000..2ad042f6b --- /dev/null +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/DSLOptimizer.java @@ -0,0 +1,9 @@ +package com.tencent.supersonic.chat.api.component; + +import com.tencent.supersonic.chat.api.pojo.CorrectionInfo; +import net.sf.jsqlparser.JSQLParserException; + +public interface DSLOptimizer { + CorrectionInfo rewriter(CorrectionInfo correctionInfo) throws JSQLParserException; + +} diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/CorrectionInfo.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/CorrectionInfo.java new file mode 100644 index 000000000..d43d0e442 --- /dev/null +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/CorrectionInfo.java @@ -0,0 +1,21 @@ +package com.tencent.supersonic.chat.api.pojo; + +import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@AllArgsConstructor +@NoArgsConstructor +@Builder +public class CorrectionInfo { + + private QueryFilters queryFilters; + + private SemanticParseInfo parseInfo; + + private String sql; + +} diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/QueryReq.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/QueryReq.java index 82f28358b..4a127ea9f 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/QueryReq.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/QueryReq.java @@ -12,4 +12,5 @@ public class QueryReq { private User user; private QueryFilters queryFilters; private boolean saveAnswer = true; + private Integer agentId; } diff --git a/chat/core/pom.xml b/chat/core/pom.xml index 2e9314c4a..70b4c61c7 100644 --- a/chat/core/pom.xml +++ b/chat/core/pom.xml @@ -40,12 +40,6 @@ compile - - com.github.plexpt - chatgpt - 4.1.2 - - org.junit.jupiter junit-jupiter diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/Agent.java b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/Agent.java new file mode 100644 index 000000000..f8cc93925 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/Agent.java @@ -0,0 +1,43 @@ +package com.tencent.supersonic.chat.agent; + + +import com.alibaba.fastjson.JSONObject; +import com.google.common.collect.Lists; +import com.tencent.supersonic.chat.agent.tool.AgentToolType; +import com.tencent.supersonic.common.pojo.RecordInfo; +import lombok.Data; +import org.springframework.util.CollectionUtils; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +@Data +public class Agent extends RecordInfo { + + private Integer id; + private Integer enableSearch; + private String name; + private String description; + + //0 offline, 1 online + private Integer status; + private List examples; + private String agentConfig; + + public List getTools(AgentToolType type) { + Map map = JSONObject.parseObject(agentConfig, Map.class); + if (CollectionUtils.isEmpty(map) || map.get("tools") == null) { + return Lists.newArrayList(); + } + List toolList = (List) map.get("tools"); + return toolList.stream() + .filter(tool -> type.name().equals(tool.get("type"))) + .map(JSONObject::toJSONString) + .collect(Collectors.toList()); + } + + public boolean enableSearch() { + return enableSearch != null && enableSearch == 1; + } + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/AgentConfig.java b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/AgentConfig.java new file mode 100644 index 000000000..9f675cead --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/AgentConfig.java @@ -0,0 +1,18 @@ +package com.tencent.supersonic.chat.agent; + +import com.google.common.collect.Lists; +import com.tencent.supersonic.chat.agent.tool.AgentTool; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; +import java.util.List; + + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class AgentConfig { + + List tools = Lists.newArrayList(); + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/AgentTool.java b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/AgentTool.java new file mode 100644 index 000000000..de51051c3 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/AgentTool.java @@ -0,0 +1,19 @@ +package com.tencent.supersonic.chat.agent.tool; + + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class AgentTool { + + private String name; + + private AgentToolType type; + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/AgentToolType.java b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/AgentToolType.java new file mode 100644 index 000000000..20500a644 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/AgentToolType.java @@ -0,0 +1,8 @@ +package com.tencent.supersonic.chat.agent.tool; + +public enum AgentToolType { + RULE, + DSL, + PLUGIN, + INTERPRET +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/DslTool.java b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/DslTool.java new file mode 100644 index 000000000..f4e64a2c1 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/DslTool.java @@ -0,0 +1,14 @@ +package com.tencent.supersonic.chat.agent.tool; + +import lombok.Data; + +import java.util.List; + +@Data +public class DslTool extends AgentTool { + + private List modelIds; + + private List exampleQuestions; + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/MetricInterpretTool.java b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/MetricInterpretTool.java new file mode 100644 index 000000000..4c71f8a87 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/MetricInterpretTool.java @@ -0,0 +1,16 @@ +package com.tencent.supersonic.chat.agent.tool; + +import com.tencent.supersonic.chat.parser.llm.interpret.MetricOption; +import lombok.Data; + +import java.util.List; + + +@Data +public class MetricInterpretTool extends AgentTool { + + private Long modelId; + + private List metricOptions; + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/PluginTool.java b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/PluginTool.java new file mode 100644 index 000000000..8ccb2671e --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/PluginTool.java @@ -0,0 +1,13 @@ +package com.tencent.supersonic.chat.agent.tool; + + +import lombok.Data; + +import java.util.List; + +@Data +public class PluginTool extends AgentTool { + + private List plugins; + +} \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/RuleQueryTool.java b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/RuleQueryTool.java new file mode 100644 index 000000000..63e8bdd49 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/agent/tool/RuleQueryTool.java @@ -0,0 +1,13 @@ +package com.tencent.supersonic.chat.agent.tool; + + +import lombok.Data; + +import java.util.List; + +@Data +public class RuleQueryTool extends AgentTool { + + private List queryModes; + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/MapperHelper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/MapperHelper.java index 0a3787153..2366f1cc7 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/MapperHelper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/MapperHelper.java @@ -19,7 +19,7 @@ import org.springframework.stereotype.Service; @Slf4j public class MapperHelper { - @Value("${one.detection.size:6}") + @Value("${one.detection.size:8}") private Integer oneDetectionSize; @Value("${one.detection.max.size:20}") private Integer oneDetectionMaxSize; @@ -64,7 +64,7 @@ public class MapperHelper { */ public boolean existDimensionValues(List natures) { for (String nature : natures) { - if (NatureHelper.isDimensionValueClassId(nature)) { + if (NatureHelper.isDimensionValueModelId(nature)) { return true; } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/QueryMatchStrategy.java b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/QueryMatchStrategy.java index 8e1a085dc..f3507cc36 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/QueryMatchStrategy.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/mapper/QueryMatchStrategy.java @@ -32,7 +32,7 @@ public class QueryMatchStrategy implements MatchStrategy { private MapperHelper mapperHelper; @Override - public Map> match(String text, List terms, Long detectmodelId) { + public Map> match(String text, List terms, Long detectModelId) { if (Objects.isNull(terms) || StringUtils.isEmpty(text)) { return null; } @@ -43,22 +43,18 @@ public class QueryMatchStrategy implements MatchStrategy { List offsetList = terms.stream().sorted(Comparator.comparing(Term::getOffset)) .map(term -> term.getOffset()).collect(Collectors.toList()); - log.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectmodelId:{}", terms, - regOffsetToLength, offsetList, detectmodelId); + log.debug("retryCount:{},terms:{},regOffsetToLength:{},offsetList:{},detectModelId:{}", terms, + regOffsetToLength, offsetList, detectModelId); - List detects = detect(text, regOffsetToLength, offsetList, detectmodelId); + List detects = detect(text, regOffsetToLength, offsetList, detectModelId); Map> result = new HashMap<>(); - MatchText matchText = MatchText.builder() - .regText(text) - .detectSegment(text) - .build(); - result.put(matchText, detects); + result.put(MatchText.builder().regText(text).detectSegment(text).build(), detects); return result; } private List detect(String text, Map regOffsetToLength, List offsetList, - Long detectmodelId) { + Long detectModelId) { List results = Lists.newArrayList(); for (Integer index = 0; index <= text.length() - 1; ) { @@ -69,7 +65,7 @@ public class QueryMatchStrategy implements MatchStrategy { int offset = mapperHelper.getStepOffset(offsetList, index); i = mapperHelper.getStepIndex(regOffsetToLength, i); if (i <= text.length()) { - List mapResults = detectByStep(text, detectmodelId, index, i, offset); + List mapResults = detectByStep(text, detectModelId, index, i, offset); selectMapResultInOneRound(mapResultRowSet, mapResults); } } @@ -106,15 +102,15 @@ public class QueryMatchStrategy implements MatchStrategy { return a.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, a.getNatures()); } - private List detectByStep(String text, Long detectmodelId, Integer index, Integer i, int offset) { + private List detectByStep(String text, Long detectModelId, Integer index, Integer i, int offset) { String detectSegment = text.substring(index, i); - Integer oneDetectionSize = mapperHelper.getOneDetectionSize(); + // step1. pre search - LinkedHashSet mapResults = SearchService.prefixSearch(detectSegment, - mapperHelper.getOneDetectionMaxSize()) + Integer oneDetectionMaxSize = mapperHelper.getOneDetectionMaxSize(); + LinkedHashSet mapResults = SearchService.prefixSearch(detectSegment, oneDetectionMaxSize) .stream().collect(Collectors.toCollection(LinkedHashSet::new)); // step2. suffix search - LinkedHashSet suffixMapResults = SearchService.suffixSearch(detectSegment, oneDetectionSize) + LinkedHashSet suffixMapResults = SearchService.suffixSearch(detectSegment, oneDetectionMaxSize) .stream().collect(Collectors.toCollection(LinkedHashSet::new)); mapResults.addAll(suffixMapResults); @@ -126,11 +122,11 @@ public class QueryMatchStrategy implements MatchStrategy { mapResults = mapResults.stream().sorted((a, b) -> -(b.getName().length() - a.getName().length())) .collect(Collectors.toCollection(LinkedHashSet::new)); // step4. filter by classId - if (Objects.nonNull(detectmodelId) && detectmodelId > 0) { - log.debug("detectmodelId:{}, before parseResults:{}", mapResults); + if (Objects.nonNull(detectModelId) && detectModelId > 0) { + log.debug("detectModelId:{}, before parseResults:{}", mapResults); mapResults = mapResults.stream().map(entry -> { List natures = entry.getNatures().stream().filter( - nature -> nature.startsWith(DictWordType.NATURE_SPILT + detectmodelId) || (nature.startsWith( + nature -> nature.startsWith(DictWordType.NATURE_SPILT + detectModelId) || (nature.startsWith( DictWordType.NATURE_SPILT)) ).collect(Collectors.toList()); entry.setNatures(natures); @@ -145,8 +141,7 @@ public class QueryMatchStrategy implements MatchStrategy { .filter(term -> CollectionUtils.isNotEmpty(term.getNatures())) .collect(Collectors.toCollection(LinkedHashSet::new)); - log.debug("metricDimensionThreshold:{},dimensionValueThreshold:{},after isSimilarity parseResults:{}", - mapResults); + log.debug("after isSimilarity parseResults:{}", mapResults); mapResults = mapResults.stream().map(parseResult -> { parseResult.setOffset(offset); @@ -165,7 +160,7 @@ public class QueryMatchStrategy implements MatchStrategy { if (CollectionUtils.isNotEmpty(dimensionMetrics)) { return dimensionMetrics; } else { - return mapResults.stream().limit(oneDetectionSize).collect(Collectors.toList()); + return mapResults.stream().limit(mapperHelper.getOneDetectionSize()).collect(Collectors.toList()); } } } \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/SatisfactionChecker.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/SatisfactionChecker.java index f1ab63981..e678714ee 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/SatisfactionChecker.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/SatisfactionChecker.java @@ -2,8 +2,8 @@ package com.tencent.supersonic.chat.parser; import com.tencent.supersonic.chat.api.component.SemanticQuery; -import com.tencent.supersonic.chat.api.pojo.QueryContext; -import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.api.pojo.*; +import com.tencent.supersonic.chat.query.dsl.DSLQuery; import lombok.extern.slf4j.Slf4j; /** @@ -21,6 +21,9 @@ public class SatisfactionChecker { // check all the parse info in candidate public static boolean check(QueryContext queryContext) { for (SemanticQuery query : queryContext.getCandidateQueries()) { + if (query.getQueryMode().equals(DSLQuery.QUERY_MODE)) { + continue; + } if (checkThreshold(queryContext.getRequest().getQueryText(), query.getParseInfo())) { return true; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/embedding/EmbeddingBasedParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/embedding/EmbeddingBasedParser.java index 87ff3f812..48e9ea74d 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/embedding/EmbeddingBasedParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/embedding/EmbeddingBasedParser.java @@ -2,14 +2,7 @@ package com.tencent.supersonic.chat.parser.embedding; import com.google.common.collect.Lists; import com.tencent.supersonic.chat.api.component.SemanticParser; -import com.tencent.supersonic.chat.api.pojo.ChatContext; -import com.tencent.supersonic.chat.api.pojo.ModelSchema; -import com.tencent.supersonic.chat.api.pojo.QueryContext; -import com.tencent.supersonic.chat.api.pojo.SchemaElement; -import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; -import com.tencent.supersonic.chat.api.pojo.SchemaElementType; -import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; -import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.api.pojo.*; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.parser.ParseMode; @@ -17,17 +10,14 @@ import com.tencent.supersonic.chat.plugin.Plugin; import com.tencent.supersonic.chat.plugin.PluginManager; import com.tencent.supersonic.chat.plugin.PluginParseResult; import com.tencent.supersonic.chat.query.QueryManager; +import com.tencent.supersonic.chat.query.dsl.DSLQuery; import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery; -import com.tencent.supersonic.chat.service.PluginService; import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum; -import java.util.Comparator; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.stream.Collectors; +import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; @@ -53,48 +43,32 @@ public class EmbeddingBasedParser implements SemanticParser { if (CollectionUtils.isEmpty(embeddingRetrievals)) { return; } - PluginService pluginService = ContextUtils.getBean(PluginService.class); - List plugins = pluginService.getPluginList(); + List plugins = getPluginList(queryContext); Map pluginMap = plugins.stream().collect(Collectors.toMap(Plugin::getId, p -> p)); for (RecallRetrieval embeddingRetrieval : embeddingRetrievals) { Plugin plugin = pluginMap.get(Long.parseLong(embeddingRetrieval.getId())); - if (plugin == null) { + if (plugin == null || DSLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) { continue; } - Pair> pair = PluginManager.resolve(plugin, queryContext); + Pair> pair = PluginManager.resolve(plugin, queryContext); log.info("embedding plugin resolve: {}", pair); if (pair.getLeft()) { - List modelList = pair.getRight(); + Set modelList = pair.getRight(); if (CollectionUtils.isEmpty(modelList)) { return; } - modelList = distinctModelList(plugin, queryContext.getMapInfo(), modelList); for (Long modelId : modelList) { buildQuery(plugin, Double.parseDouble(embeddingRetrieval.getDistance()), modelId, queryContext, queryContext.getMapInfo().getMatchedElements(modelId)); + if (plugin.isContainsAllModel()) { + break; + } } return; } } } - public List distinctModelList(Plugin plugin, SchemaMapInfo schemaMapInfo, List modelList) { - if (!plugin.isContainsAllModel()) { - return modelList; - } - boolean noElementMatch = true; - for (Long model : modelList) { - List schemaElementMatches = schemaMapInfo.getMatchedElements(model); - if (!CollectionUtils.isEmpty(schemaElementMatches)) { - noElementMatch = false; - } - } - if (noElementMatch) { - return modelList.subList(0, 1); - } - return modelList; - } - private void buildQuery(Plugin plugin, double distance, Long modelId, QueryContext queryContext, List schemaElementMatches) { log.info("EmbeddingBasedParser Model: {} choose plugin: [{} {}]", modelId, plugin.getId(), plugin.getName()); @@ -126,6 +100,8 @@ public class EmbeddingBasedParser implements SemanticParser { pluginParseResult.setRequest(queryReq); pluginParseResult.setDistance(distance); properties.put(Constants.CONTEXT, pluginParseResult); + properties.put("type", "plugin"); + properties.put("name", plugin.getName()); semanticParseInfo.setProperties(properties); semanticParseInfo.setScore(distance); fillSemanticParseInfo(semanticParseInfo); @@ -176,4 +152,8 @@ public class EmbeddingBasedParser implements SemanticParser { } } + protected List getPluginList(QueryContext queryContext) { + return PluginManager.getPluginAgentCanSupport(queryContext.getRequest().getAgentId()); + } + } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/function/FunctionBasedParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/function/FunctionBasedParser.java index 702af0473..f95732d18 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/function/FunctionBasedParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/function/FunctionBasedParser.java @@ -14,23 +14,16 @@ import com.tencent.supersonic.chat.plugin.PluginManager; import com.tencent.supersonic.chat.plugin.PluginParseConfig; import com.tencent.supersonic.chat.plugin.PluginParseResult; import com.tencent.supersonic.chat.query.QueryManager; -import com.tencent.supersonic.chat.query.dsl.DSLQuery; import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery; +import com.tencent.supersonic.chat.query.dsl.DSLQuery; import com.tencent.supersonic.chat.service.PluginService; import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.common.util.JsonUtil; import java.net.URI; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.Set; +import java.util.*; import java.util.stream.Collectors; +import com.tencent.supersonic.common.util.JsonUtil; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; @@ -62,6 +55,10 @@ public class FunctionBasedParser implements SemanticParser { return; } List functionDOList = getFunctionDO(queryCtx.getRequest().getModelId(), queryCtx); + if (CollectionUtils.isEmpty(functionDOList)) { + log.info("function call parser, plugin is empty, skip"); + return; + } FunctionReq functionReq = FunctionReq.builder() .queryText(queryCtx.getRequest().getQueryText()) .pluginConfigs(functionDOList).build(); @@ -85,7 +82,7 @@ public class FunctionBasedParser implements SemanticParser { PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(toolSelection); ModelResolver ModelResolver = ComponentFactory.getModelResolver(); log.info("plugin ModelList:{}", plugin.getModelList()); - Pair> pluginResolveResult = PluginManager.resolve(plugin, queryCtx); + Pair> pluginResolveResult = PluginManager.resolve(plugin, queryCtx); Long modelId = ModelResolver.resolve(queryCtx, chatCtx, pluginResolveResult.getRight()); log.info("FunctionBasedParser modelId:{}", modelId); if ((Objects.isNull(modelId) || modelId <= 0) && !plugin.isContainsAllModel()) { @@ -102,6 +99,8 @@ public class FunctionBasedParser implements SemanticParser { functionCallParseResult.setRequest(queryCtx.getRequest()); Map properties = new HashMap<>(); properties.put(Constants.CONTEXT, functionCallParseResult); + properties.put("type", "plugin"); + properties.put("name", plugin.getName()); parseInfo.setProperties(properties); parseInfo.setScore(FUNCTION_BONUS_THRESHOLD); parseInfo.setQueryMode(semanticQuery.getQueryMode()); @@ -112,17 +111,6 @@ public class FunctionBasedParser implements SemanticParser { queryCtx.getCandidateQueries().add(semanticQuery); } - - private Set getMatchModels(QueryContext queryCtx) { - Set result = new HashSet<>(); - Long modelId = queryCtx.getRequest().getModelId(); - if (Objects.nonNull(modelId) && modelId > 0) { - result.add(modelId); - return result; - } - return queryCtx.getMapInfo().getMatchedModels(); - } - private boolean skipFunction(QueryContext queryCtx, FunctionResp functionResp) { if (Objects.isNull(functionResp) || StringUtils.isBlank(functionResp.getToolSelection())) { return true; @@ -140,7 +128,7 @@ public class FunctionBasedParser implements SemanticParser { private List getFunctionDO(Long modelId, QueryContext queryContext) { log.info("user decide Model:{}", modelId); - List plugins = PluginManager.getPlugins(); + List plugins = getPluginList(queryContext); List functionDOList = plugins.stream().filter(plugin -> { if (DSLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) { return false; @@ -153,12 +141,12 @@ public class FunctionBasedParser implements SemanticParser { if (StringUtils.isBlank(pluginParseConfig.getName())) { return false; } - Pair> pluginResolverResult = PluginManager.resolve(plugin, queryContext); - log.info("embedding plugin [{}-{}] resolve: {}", plugin.getId(), plugin.getName(), pluginResolverResult); + Pair> pluginResolverResult = PluginManager.resolve(plugin, queryContext); + log.info("plugin [{}-{}] resolve: {}", plugin.getId(), plugin.getName(), pluginResolverResult); if (!pluginResolverResult.getLeft()) { return false; } else { - List resolveModel = pluginResolverResult.getRight(); + Set resolveModel = pluginResolverResult.getRight(); if (modelId != null && modelId > 0) { if (plugin.isContainsAllModel()) { return true; @@ -172,20 +160,6 @@ public class FunctionBasedParser implements SemanticParser { return functionDOList; } - private List getFunctionNames(Set matchedModels) { - List plugins = PluginManager.getPlugins(); - Set functionNames = plugins.stream() - .filter(entry -> { - if (!CollectionUtils.isEmpty(entry.getModelList()) && !CollectionUtils.isEmpty(matchedModels)) { - return entry.getModelList().stream().anyMatch(matchedModels::contains); - } - return true; - } - ).map(Plugin::getName).collect(Collectors.toSet()); - functionNames.add(DSLQuery.QUERY_MODE); - return new ArrayList<>(functionNames); - } - public FunctionResp requestFunction(String url, FunctionReq functionReq) { HttpHeaders headers = new HttpHeaders(); long startTime = System.currentTimeMillis(); @@ -205,4 +179,8 @@ public class FunctionBasedParser implements SemanticParser { } return null; } + + protected List getPluginList(QueryContext queryContext) { + return PluginManager.getPluginAgentCanSupport(queryContext.getRequest().getAgentId()); + } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/function/HeuristicModelResolver.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/function/HeuristicModelResolver.java index 4e77a1e4a..e7c447258 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/function/HeuristicModelResolver.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/function/HeuristicModelResolver.java @@ -1,21 +1,12 @@ package com.tencent.supersonic.chat.parser.function; -import com.tencent.supersonic.chat.api.component.SemanticQuery; -import com.tencent.supersonic.chat.api.pojo.ChatContext; -import com.tencent.supersonic.chat.api.pojo.QueryContext; -import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; -import com.tencent.supersonic.chat.api.pojo.SchemaElementType; -import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; -import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.api.pojo.*; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.stream.Collectors; +import com.tencent.supersonic.chat.api.component.SemanticQuery; import lombok.extern.slf4j.Slf4j; + +import java.util.*; +import java.util.stream.Collectors; import org.apache.commons.collections.CollectionUtils; @Slf4j @@ -55,7 +46,7 @@ public class HeuristicModelResolver implements ModelResolver { * @return false will use context Model, true will use other Model , maybe include context Model */ protected static boolean isAllowSwitch(Map ModelQueryModes, SchemaMapInfo schemaMap, - ChatContext chatCtx, QueryReq searchCtx, Long modelId, List restrictiveModels) { + ChatContext chatCtx, QueryReq searchCtx, Long modelId, Set restrictiveModels) { if (!Objects.nonNull(modelId) || modelId <= 0) { return true; } @@ -80,8 +71,7 @@ public class HeuristicModelResolver implements ModelResolver { } if (searchCtx.getQueryText() != null && semanticParseInfo.getDateInfo() != null) { if (semanticParseInfo.getDateInfo().getDetectWord() != null) { - if (semanticParseInfo.getDateInfo().getDetectWord() - .equalsIgnoreCase(searchCtx.getQueryText())) { + if (semanticParseInfo.getDateInfo().getDetectWord().equalsIgnoreCase(searchCtx.getQueryText())) { log.info("timeParseResults is not null , can not switch context , timeParseResults:{},", semanticParseInfo.getDateInfo()); return false; @@ -131,7 +121,7 @@ public class HeuristicModelResolver implements ModelResolver { } - public Long resolve(QueryContext queryContext, ChatContext chatCtx, List restrictiveModels) { + public Long resolve(QueryContext queryContext, ChatContext chatCtx, Set restrictiveModels) { Long modelId = queryContext.getRequest().getModelId(); if (Objects.nonNull(modelId) && modelId > 0) { if (CollectionUtils.isNotEmpty(restrictiveModels) && restrictiveModels.contains(modelId)) { @@ -151,17 +141,16 @@ public class HeuristicModelResolver implements ModelResolver { for (Long matchedModel : matchedModels) { ModelQueryModes.put(matchedModel, null); } - if (ModelQueryModes.size() == 1) { + if(ModelQueryModes.size()==1){ return ModelQueryModes.keySet().stream().findFirst().get(); } return resolve(ModelQueryModes, queryContext, chatCtx, - queryContext.getMapInfo(), restrictiveModels); + queryContext.getMapInfo(),restrictiveModels); } public Long resolve(Map ModelQueryModes, QueryContext queryContext, - ChatContext chatCtx, SchemaMapInfo schemaMap, List restrictiveModels) { - Long selectModel = selectModel(ModelQueryModes, queryContext.getRequest(), chatCtx, schemaMap, - restrictiveModels); + ChatContext chatCtx, SchemaMapInfo schemaMap, Set restrictiveModels) { + Long selectModel = selectModel(ModelQueryModes, queryContext.getRequest(), chatCtx, schemaMap,restrictiveModels); if (selectModel > 0) { log.info("selectModel {} ", selectModel); return selectModel; @@ -172,7 +161,7 @@ public class HeuristicModelResolver implements ModelResolver { public Long selectModel(Map ModelQueryModes, QueryReq queryContext, ChatContext chatCtx, - SchemaMapInfo schemaMap, List restrictiveModels) { + SchemaMapInfo schemaMap, Set restrictiveModels) { // if QueryContext has modelId and in ModelQueryModes if (ModelQueryModes.containsKey(queryContext.getModelId())) { log.info("selectModel from QueryContext [{}]", queryContext.getModelId()); @@ -181,7 +170,7 @@ public class HeuristicModelResolver implements ModelResolver { // if ChatContext has modelId and in ModelQueryModes if (chatCtx.getParseInfo().getModelId() > 0) { Long modelId = chatCtx.getParseInfo().getModelId(); - if (!isAllowSwitch(ModelQueryModes, schemaMap, chatCtx, queryContext, modelId, restrictiveModels)) { + if (!isAllowSwitch(ModelQueryModes, schemaMap, chatCtx, queryContext, modelId,restrictiveModels)) { log.info("selectModel from ChatContext [{}]", modelId); return modelId; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/function/ModelResolver.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/function/ModelResolver.java index cdabc8d04..7bb68ee66 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/function/ModelResolver.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/function/ModelResolver.java @@ -4,9 +4,10 @@ package com.tencent.supersonic.chat.parser.function; import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.QueryContext; import java.util.List; +import java.util.Set; public interface ModelResolver { - Long resolve(QueryContext queryContext, ChatContext chatCtx, List restrictiveModels); + Long resolve(QueryContext queryContext, ChatContext chatCtx, Set restrictiveModels); } \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/DSLParseResult.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/DSLParseResult.java deleted file mode 100644 index f648ee47c..000000000 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/DSLParseResult.java +++ /dev/null @@ -1,11 +0,0 @@ -package com.tencent.supersonic.chat.parser.llm; - -import com.tencent.supersonic.chat.plugin.PluginParseResult; -import com.tencent.supersonic.chat.query.dsl.LLMResp; -import lombok.Data; - -@Data -public class DSLParseResult extends PluginParseResult { - - private LLMResp llmResp; -} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/DSLDateHelper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/DSLDateHelper.java new file mode 100644 index 000000000..46dac8beb --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/DSLDateHelper.java @@ -0,0 +1,20 @@ +package com.tencent.supersonic.chat.parser.llm.dsl; + +import com.tencent.supersonic.common.util.DateUtils; + +public class DSLDateHelper { + + public static String getCurrentDate(Long modelId) { + return DateUtils.getBeforeDate(4); +// ChatConfigFilter filter = new ChatConfigFilter(); +// filter.setModelId(modelId); +// +// List configResps = ContextUtils.getBean(ConfigService.class).search(filter, null); +// if (CollectionUtils.isEmpty(configResps)) { +// return +// } +// ChatConfigResp chatConfigResp = configResps.get(0); +// chatConfigResp.getChatDetailConfig().getChatDefaultConfig().get + + } +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/DSLParseResult.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/DSLParseResult.java new file mode 100644 index 000000000..6a9590ff2 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/DSLParseResult.java @@ -0,0 +1,17 @@ +package com.tencent.supersonic.chat.parser.llm.dsl; + +import com.tencent.supersonic.chat.agent.tool.DslTool; +import com.tencent.supersonic.chat.api.pojo.request.QueryReq; +import com.tencent.supersonic.chat.plugin.PluginParseResult; +import com.tencent.supersonic.chat.query.dsl.LLMResp; +import lombok.Data; + +@Data +public class DSLParseResult { + + private LLMResp llmResp; + + private QueryReq request; + + private DslTool dslTool; +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/LLMDSLParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDSLParser.java similarity index 82% rename from chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/LLMDSLParser.java rename to chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDSLParser.java index 72c2bace9..d89e2b551 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/LLMDSLParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDSLParser.java @@ -1,5 +1,10 @@ -package com.tencent.supersonic.chat.parser.llm; +package com.tencent.supersonic.chat.parser.llm.dsl; +import com.alibaba.fastjson.JSONObject; +import com.google.common.collect.Lists; +import com.tencent.supersonic.chat.agent.Agent; +import com.tencent.supersonic.chat.agent.tool.AgentToolType; +import com.tencent.supersonic.chat.agent.tool.DslTool; import com.tencent.supersonic.chat.api.component.SemanticParser; import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.QueryContext; @@ -11,26 +16,26 @@ import com.tencent.supersonic.chat.api.pojo.SemanticSchema; import com.tencent.supersonic.chat.config.LLMConfig; import com.tencent.supersonic.chat.parser.SatisfactionChecker; import com.tencent.supersonic.chat.parser.function.ModelResolver; -import com.tencent.supersonic.chat.plugin.Plugin; -import com.tencent.supersonic.chat.plugin.PluginManager; import com.tencent.supersonic.chat.query.QueryManager; -import com.tencent.supersonic.chat.query.dsl.DSLBuilder; import com.tencent.supersonic.chat.query.dsl.DSLQuery; import com.tencent.supersonic.chat.query.dsl.LLMReq; import com.tencent.supersonic.chat.query.dsl.LLMReq.ElementValue; import com.tencent.supersonic.chat.query.dsl.LLMResp; +import com.tencent.supersonic.chat.query.dsl.optimizer.BaseDSLOptimizer; import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery; +import com.tencent.supersonic.chat.service.AgentService; import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.common.util.DateUtils; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.knowledge.service.SchemaService; import java.util.ArrayList; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; @@ -56,30 +61,30 @@ public class LLMDSLParser implements SemanticParser { queryCtx.getRequest().getQueryText()); return; } - List dslPlugins = PluginManager.getPlugins().stream() - .filter(plugin -> DSLQuery.QUERY_MODE.equalsIgnoreCase(plugin.getType())) - .collect(Collectors.toList()); - - if (CollectionUtils.isEmpty(dslPlugins)) { - return; - } - Plugin plugin = dslPlugins.get(0); - List dslModels = plugin.getModelList(); + List dslTools = getDslTools(queryCtx.getRequest().getAgentId()); + Set distinctModelIds = dslTools.stream().map(DslTool::getModelIds) + .flatMap(Collection::stream) + .collect(Collectors.toSet()); try { ModelResolver modelResolver = ComponentFactory.getModelResolver(); - Long modelId = modelResolver.resolve(queryCtx, chatCtx, dslModels); - log.info("resolve modelId:{},dslModels:{}", modelId, dslModels); + Long modelId = modelResolver.resolve(queryCtx, chatCtx, distinctModelIds); + log.info("resolve modelId:{},dslModels:{}", modelId, distinctModelIds); - if (Objects.isNull(modelId)) { + if (Objects.isNull(modelId) || modelId <= 0) { return; } - + Optional dslToolOptional = dslTools.stream().filter(tool -> + tool.getModelIds().contains(modelId)).findFirst(); + if (!dslToolOptional.isPresent()) { + log.info("no dsl tool in this agent, skip dsl parser"); + return; + } + DslTool dslTool = dslToolOptional.get(); LLMResp llmResp = requestLLM(queryCtx, modelId); if (Objects.isNull(llmResp)) { return; } - PluginSemanticQuery semanticQuery = QueryManager.createPluginQuery(DSLQuery.QUERY_MODE); SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); @@ -89,10 +94,12 @@ public class LLMDSLParser implements SemanticParser { DSLParseResult dslParseResult = new DSLParseResult(); dslParseResult.setRequest(queryCtx.getRequest()); dslParseResult.setLlmResp(llmResp); - dslParseResult.setPlugin(plugin); + dslParseResult.setDslTool(dslToolOptional.get()); Map properties = new HashMap<>(); properties.put(Constants.CONTEXT, dslParseResult); + properties.put("type", "internal"); + properties.put("name", dslTool.getName()); parseInfo.setProperties(properties); parseInfo.setScore(FUNCTION_BONUS_THRESHOLD); parseInfo.setQueryMode(semanticQuery.getQueryMode()); @@ -126,13 +133,13 @@ public class LLMDSLParser implements SemanticParser { llmSchema.setModelName(modelIdToName.get(modelId)); llmSchema.setDomainName(modelIdToName.get(modelId)); List fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema); - fieldNameList.add(DSLBuilder.DATA_Field); + fieldNameList.add(BaseDSLOptimizer.DATE_FIELD); llmSchema.setFieldNameList(fieldNameList); llmReq.setSchema(llmSchema); List linking = new ArrayList<>(); linking.addAll(getValueList(queryCtx, modelId, semanticSchema)); llmReq.setLinking(linking); - String currentDate = getCurrentDate(modelId); + String currentDate = DSLDateHelper.getCurrentDate(modelId); llmReq.setCurrentDate(currentDate); log.info("requestLLM request, modelId:{},llmReq:{}", modelId, llmReq); @@ -156,21 +163,6 @@ public class LLMDSLParser implements SemanticParser { return null; } - - private String getCurrentDate(Long modelId) { - return DateUtils.getBeforeDate(4); -// ChatConfigFilter filter = new ChatConfigFilter(); -// filter.setModelId(modelId); -// -// List configResps = ContextUtils.getBean(ConfigService.class).search(filter, null); -// if (CollectionUtils.isEmpty(configResps)) { -// return -// } -// ChatConfigResp chatConfigResp = configResps.get(0); -// chatConfigResp.getChatDetailConfig().getChatDefaultConfig().get - - } - private List getValueList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) { Map itemIdToName = getItemIdToName(modelId, semanticSchema); @@ -228,4 +220,18 @@ public class LLMDSLParser implements SemanticParser { .collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2)); } + private List getDslTools(Integer agentId) { + AgentService agentService = ContextUtils.getBean(AgentService.class); + Agent agent = agentService.getAgent(agentId); + if (agent == null) { + return Lists.newArrayList(); + } + List tools = agent.getTools(AgentToolType.DSL); + if (CollectionUtils.isEmpty(tools)) { + return Lists.newArrayList(); + } + return tools.stream().map(tool -> JSONObject.parseObject(tool, DslTool.class)) + .collect(Collectors.toList()); + } + } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/interpret/MetricInterpretParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/interpret/MetricInterpretParser.java new file mode 100644 index 000000000..1ceab6c25 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/interpret/MetricInterpretParser.java @@ -0,0 +1,144 @@ +package com.tencent.supersonic.chat.parser.llm.interpret; + +import com.alibaba.fastjson.JSONObject; +import com.google.common.collect.Sets; +import com.tencent.supersonic.chat.agent.Agent; +import com.tencent.supersonic.chat.agent.tool.AgentToolType; +import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool; +import com.tencent.supersonic.chat.api.component.SemanticLayer; +import com.tencent.supersonic.chat.api.component.SemanticParser; +import com.tencent.supersonic.chat.api.pojo.*; +import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; +import com.tencent.supersonic.chat.api.pojo.request.QueryReq; +import com.tencent.supersonic.chat.parser.SatisfactionChecker; +import com.tencent.supersonic.chat.query.metricInterpret.MetricInterpretQuery; +import com.tencent.supersonic.chat.query.QueryManager; +import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery; +import com.tencent.supersonic.chat.service.AgentService; +import com.tencent.supersonic.chat.utils.ComponentFactory; +import com.tencent.supersonic.common.pojo.DateConf; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; +import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum; +import lombok.extern.slf4j.Slf4j; +import org.springframework.util.CollectionUtils; +import java.util.*; +import java.util.stream.Collectors; + +@Slf4j +public class MetricInterpretParser implements SemanticParser { + + @Override + public void parse(QueryContext queryContext, ChatContext chatContext) { + if (SatisfactionChecker.check(queryContext)) { + log.info("skip MetricInterpretParser"); + return; + } + Map metricInterpretToolMap = getMetricInterpretTools(queryContext.getRequest().getAgentId()); + log.info("metric interpret tool : {}", metricInterpretToolMap); + if (CollectionUtils.isEmpty(metricInterpretToolMap)) { + return; + } + Map> elementMatches = queryContext.getMapInfo().getModelElementMatches(); + for (Long modelId : elementMatches.keySet()) { + MetricInterpretTool metricInterpretTool = metricInterpretToolMap.get(modelId); + if (metricInterpretTool == null) { + continue; + } + if (CollectionUtils.isEmpty(elementMatches.get(modelId))) { + continue; + } + List metricOptions = metricInterpretTool.getMetricOptions(); + if (!CollectionUtils.isEmpty(metricOptions)) { + List metricIds = metricOptions.stream().map(MetricOption::getMetricId).collect(Collectors.toList()); + buildQuery(modelId, queryContext, metricIds, elementMatches.get(modelId), metricInterpretTool.getName()); + } + } + } + + private void buildQuery(Long modelId, QueryContext queryContext, + List metricIds, List schemaElementMatches, String toolName) { + PluginSemanticQuery metricInterpretQuery = QueryManager.createPluginQuery(MetricInterpretQuery.QUERY_MODE); + Set metrics = getMetrics(metricIds, modelId); + SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(modelId, queryContext.getRequest(), + metrics, schemaElementMatches, toolName); + semanticParseInfo.setQueryMode(metricInterpretQuery.getQueryMode()); + semanticParseInfo.getProperties().put("queryText", queryContext.getRequest().getQueryText()); + metricInterpretQuery.setParseInfo(semanticParseInfo); + queryContext.getCandidateQueries().add(metricInterpretQuery); + } + + public Set getMetrics(List metricIds, Long modelId) { + SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer(); + ModelSchema modelSchema = semanticLayer.getModelSchema(modelId, true); + Set metrics = modelSchema.getMetrics(); + return metrics.stream().filter(schemaElement -> metricIds.contains(schemaElement.getId())) + .collect(Collectors.toSet()); + } + + private Map getMetricInterpretTools(Integer agentId) { + AgentService agentService = ContextUtils.getBean(AgentService.class); + Agent agent = agentService.getAgent(agentId); + if (agent == null) { + return new HashMap<>(); + } + List tools= agent.getTools(AgentToolType.INTERPRET); + if (CollectionUtils.isEmpty(tools)) { + return new HashMap<>(); + } + List metricInterpretTools = tools.stream().map(tool -> + JSONObject.parseObject(tool, MetricInterpretTool.class)) + .filter(tool -> !CollectionUtils.isEmpty(tool.getMetricOptions())) + .collect(Collectors.toList()); + Map metricInterpretToolMap = new HashMap<>(); + for (MetricInterpretTool metricInterpretTool : metricInterpretTools) { + metricInterpretToolMap.putIfAbsent(metricInterpretTool.getModelId(), + metricInterpretTool); + } + return metricInterpretToolMap; + } + + private SemanticParseInfo buildSemanticParseInfo(Long modelId, QueryReq queryReq, Set metrics, + List schemaElementMatches, String toolName) { + SchemaElement Model = new SchemaElement(); + Model.setModel(modelId); + Model.setId(modelId); + SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); + semanticParseInfo.setMetrics(metrics); + SchemaElement dimension = new SchemaElement(); + dimension.setBizName(TimeDimensionEnum.DAY.getName()); + semanticParseInfo.setDimensions(Sets.newHashSet(dimension)); + semanticParseInfo.setElementMatches(schemaElementMatches); + semanticParseInfo.setModel(Model); + semanticParseInfo.setScore(queryReq.getQueryText().length()); + DateConf dateConf = new DateConf(); + dateConf.setDateMode(DateConf.DateMode.RECENT); + dateConf.setUnit(15); + semanticParseInfo.setDateInfo(dateConf); + Map properties = new HashMap<>(); + properties.put("type", "internal"); + properties.put("name", toolName); + semanticParseInfo.setProperties(properties); + fillSemanticParseInfo(semanticParseInfo); + return semanticParseInfo; + } + + private void fillSemanticParseInfo(SemanticParseInfo semanticParseInfo) { + List schemaElementMatches = semanticParseInfo.getElementMatches(); + if (!CollectionUtils.isEmpty(schemaElementMatches)) { + schemaElementMatches.stream().filter(schemaElementMatch -> + SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()) + || SchemaElementType.ID.equals(schemaElementMatch.getElement().getType())) + .forEach(schemaElementMatch -> { + QueryFilter queryFilter = new QueryFilter(); + queryFilter.setValue(schemaElementMatch.getWord()); + queryFilter.setElementID(schemaElementMatch.getElement().getId()); + queryFilter.setName(schemaElementMatch.getElement().getName()); + queryFilter.setOperator(FilterOperatorEnum.EQUALS); + queryFilter.setBizName(schemaElementMatch.getElement().getBizName()); + semanticParseInfo.getDimensionFilters().add(queryFilter); + }); + } + } + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/interpret/MetricOption.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/interpret/MetricOption.java new file mode 100644 index 000000000..46811ff97 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/interpret/MetricOption.java @@ -0,0 +1,14 @@ +package com.tencent.supersonic.chat.parser.llm.interpret; + + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@AllArgsConstructor +@NoArgsConstructor +public class MetricOption { + + private Long metricId; +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/LLMTimeEnhancementParse.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/time/LLMTimeEnhancementParse.java similarity index 69% rename from chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/LLMTimeEnhancementParse.java rename to chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/time/LLMTimeEnhancementParse.java index b346efb63..95ff25b0c 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/LLMTimeEnhancementParse.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/time/LLMTimeEnhancementParse.java @@ -1,4 +1,4 @@ -package com.tencent.supersonic.chat.parser.llm; +package com.tencent.supersonic.chat.parser.llm.time; import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; @@ -6,8 +6,8 @@ import com.tencent.supersonic.chat.api.component.SemanticParser; import com.tencent.supersonic.chat.api.component.SemanticQuery; import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.QueryContext; -import com.tencent.supersonic.chat.utils.ChatGptHelper; import com.tencent.supersonic.common.pojo.DateConf; +import com.tencent.supersonic.common.util.ChatGptHelper; import com.tencent.supersonic.common.util.ContextUtils; import lombok.extern.slf4j.Slf4j; @@ -17,20 +17,20 @@ public class LLMTimeEnhancementParse implements SemanticParser { @Override public void parse(QueryContext queryContext, ChatContext chatContext) { - log.info("before queryContext:{},chatContext:{}", queryContext, chatContext); + log.info("before queryContext:{},chatContext:{}",queryContext,chatContext); ChatGptHelper chatGptHelper = ContextUtils.getBean(ChatGptHelper.class); - String inferredTime = chatGptHelper.inferredTime(queryContext.getRequest().getQueryText()); try { + String inferredTime = chatGptHelper.inferredTime(queryContext.getRequest().getQueryText()); if (!queryContext.getCandidateQueries().isEmpty()) { for (SemanticQuery query : queryContext.getCandidateQueries()) { DateConf dateInfo = query.getParseInfo().getDateInfo(); JSONObject jsonObject = JSON.parseObject(inferredTime); - if (jsonObject.containsKey("date")) { + if (jsonObject.containsKey("date")){ dateInfo.setDateMode(DateConf.DateMode.BETWEEN); dateInfo.setStartDate(jsonObject.getString("date")); dateInfo.setEndDate(jsonObject.getString("date")); query.getParseInfo().setDateInfo(dateInfo); - } else if (jsonObject.containsKey("start")) { + }else if (jsonObject.containsKey("start")){ dateInfo.setDateMode(DateConf.DateMode.BETWEEN); dateInfo.setStartDate(jsonObject.getString("start")); dateInfo.setEndDate(jsonObject.getString("end")); @@ -38,12 +38,11 @@ public class LLMTimeEnhancementParse implements SemanticParser { } } } - } catch (Exception exception) { - log.error("{} parse error,this reason is:{}", LLMTimeEnhancementParse.class.getSimpleName(), - (Object) exception.getStackTrace()); + }catch (Exception exception){ + log.error("{} parse error,this reason is:{}",LLMTimeEnhancementParse.class.getSimpleName(), (Object) exception.getStackTrace()); } - log.info("after queryContext:{},chatContext:{}", queryContext, chatContext); + log.info("{} after queryContext:{},chatContext:{}",LLMTimeEnhancementParse.class.getSimpleName(),queryContext,chatContext); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/AgentCheckParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/AgentCheckParser.java new file mode 100644 index 000000000..157c5c606 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/AgentCheckParser.java @@ -0,0 +1,63 @@ +package com.tencent.supersonic.chat.parser.rule; + +import com.alibaba.fastjson.JSONObject; +import com.google.common.collect.Lists; +import com.tencent.supersonic.chat.agent.Agent; +import com.tencent.supersonic.chat.agent.tool.AgentToolType; +import com.tencent.supersonic.chat.agent.tool.RuleQueryTool; +import com.tencent.supersonic.chat.api.component.SemanticParser; +import com.tencent.supersonic.chat.api.component.SemanticQuery; +import com.tencent.supersonic.chat.api.pojo.ChatContext; +import com.tencent.supersonic.chat.api.pojo.QueryContext; +import com.tencent.supersonic.chat.service.AgentService; +import com.tencent.supersonic.common.util.ContextUtils; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.collections.CollectionUtils; +import java.util.Collection; +import java.util.List; +import java.util.stream.Collectors; + +@Slf4j +public class AgentCheckParser implements SemanticParser { + + @Override + public void parse(QueryContext queryContext, ChatContext chatContext) { + List queries = queryContext.getCandidateQueries(); + agentCanSupport(queryContext.getRequest().getAgentId(), queries); + } + + private void agentCanSupport(Integer agentId, List queries) { + AgentService agentService = ContextUtils.getBean(AgentService.class); + Agent agent = agentService.getAgent(agentId); + if (agent == null) { + return; + } + List queryModes = getRuleTools(agentId).stream().map(RuleQueryTool::getQueryModes) + .flatMap(Collection::stream).collect(Collectors.toList()); + if (CollectionUtils.isEmpty(queries)) { + queries.clear(); + return; + } + log.info("queries resolved:{} {}", agent.getName(), + queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList())); + queries.removeIf(query -> + !queryModes.contains(query.getQueryMode())); + log.info("rule queries witch can be supported by agent :{} {}", agent.getName(), + queries.stream().map(SemanticQuery::getQueryMode).collect(Collectors.toList())); + } + + private static List getRuleTools(Integer agentId) { + AgentService agentService = ContextUtils.getBean(AgentService.class); + Agent agent = agentService.getAgent(agentId); + if (agent == null) { + return Lists.newArrayList(); + } + List tools = agent.getTools(AgentToolType.RULE); + if (CollectionUtils.isEmpty(tools)) { + return Lists.newArrayList(); + } + return tools.stream().map(tool -> JSONObject.parseObject(tool, RuleQueryTool.class)) + .collect(Collectors.toList()); + } + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/QueryModeParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/QueryModeParser.java index 1644f3ea2..983d3cb62 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/QueryModeParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/rule/QueryModeParser.java @@ -1,12 +1,9 @@ package com.tencent.supersonic.chat.parser.rule; import com.tencent.supersonic.chat.api.component.SemanticParser; -import com.tencent.supersonic.chat.api.pojo.ChatContext; -import com.tencent.supersonic.chat.api.pojo.QueryContext; -import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; -import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; +import com.tencent.supersonic.chat.api.pojo.*; import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery; -import java.util.List; +import java.util.*; import lombok.extern.slf4j.Slf4j; @Slf4j @@ -15,12 +12,10 @@ public class QueryModeParser implements SemanticParser { @Override public void parse(QueryContext queryContext, ChatContext chatContext) { SchemaMapInfo mapInfo = queryContext.getMapInfo(); - // iterate all schemaElementMatches to resolve semantic query for (Long modelId : mapInfo.getMatchedModels()) { List elementMatches = mapInfo.getMatchedElements(modelId); List queries = RuleSemanticQuery.resolve(elementMatches, queryContext); - for (RuleSemanticQuery query : queries) { query.fillParseInfo(modelId, queryContext, chatContext); queryContext.getCandidateQueries().add(query); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/dataobject/AgentDO.java b/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/dataobject/AgentDO.java new file mode 100644 index 000000000..9777a9492 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/dataobject/AgentDO.java @@ -0,0 +1,236 @@ +package com.tencent.supersonic.chat.persistence.dataobject; + +import java.util.Date; + +public class AgentDO { + /** + * + */ + private Integer id; + + /** + * + */ + private String name; + + /** + * + */ + private String description; + + /** + * 0 offline, 1 online + */ + private Integer status; + + /** + * + */ + private String examples; + + /** + * + */ + private String config; + + /** + * + */ + private String createdBy; + + /** + * + */ + private Date createdAt; + + /** + * + */ + private String updatedBy; + + /** + * + */ + private Date updatedAt; + + /** + * + */ + private Integer enableSearch; + + /** + * + * @return id + */ + public Integer getId() { + return id; + } + + /** + * + * @param id + */ + public void setId(Integer id) { + this.id = id; + } + + /** + * + * @return name + */ + public String getName() { + return name; + } + + /** + * + * @param name + */ + public void setName(String name) { + this.name = name == null ? null : name.trim(); + } + + /** + * + * @return description + */ + public String getDescription() { + return description; + } + + /** + * + * @param description + */ + public void setDescription(String description) { + this.description = description == null ? null : description.trim(); + } + + /** + * 0 offline, 1 online + * @return status 0 offline, 1 online + */ + public Integer getStatus() { + return status; + } + + /** + * 0 offline, 1 online + * @param status 0 offline, 1 online + */ + public void setStatus(Integer status) { + this.status = status; + } + + /** + * + * @return examples + */ + public String getExamples() { + return examples; + } + + /** + * + * @param examples + */ + public void setExamples(String examples) { + this.examples = examples == null ? null : examples.trim(); + } + + /** + * + * @return config + */ + public String getConfig() { + return config; + } + + /** + * + * @param config + */ + public void setConfig(String config) { + this.config = config == null ? null : config.trim(); + } + + /** + * + * @return created_by + */ + public String getCreatedBy() { + return createdBy; + } + + /** + * + * @param createdBy + */ + public void setCreatedBy(String createdBy) { + this.createdBy = createdBy == null ? null : createdBy.trim(); + } + + /** + * + * @return created_at + */ + public Date getCreatedAt() { + return createdAt; + } + + /** + * + * @param createdAt + */ + public void setCreatedAt(Date createdAt) { + this.createdAt = createdAt; + } + + /** + * + * @return updated_by + */ + public String getUpdatedBy() { + return updatedBy; + } + + /** + * + * @param updatedBy + */ + public void setUpdatedBy(String updatedBy) { + this.updatedBy = updatedBy == null ? null : updatedBy.trim(); + } + + /** + * + * @return updated_at + */ + public Date getUpdatedAt() { + return updatedAt; + } + + /** + * + * @param updatedAt + */ + public void setUpdatedAt(Date updatedAt) { + this.updatedAt = updatedAt; + } + + /** + * + * @return enable_search + */ + public Integer getEnableSearch() { + return enableSearch; + } + + /** + * + * @param enableSearch + */ + public void setEnableSearch(Integer enableSearch) { + this.enableSearch = enableSearch; + } +} \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/dataobject/AgentDOExample.java b/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/dataobject/AgentDOExample.java new file mode 100644 index 000000000..fb4bdda0f --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/dataobject/AgentDOExample.java @@ -0,0 +1,1025 @@ +package com.tencent.supersonic.chat.persistence.dataobject; + +import java.util.ArrayList; +import java.util.Date; +import java.util.List; + +public class AgentDOExample { + /** + * s2_agent + */ + protected String orderByClause; + + /** + * s2_agent + */ + protected boolean distinct; + + /** + * s2_agent + */ + protected List oredCriteria; + + /** + * s2_agent + */ + protected Integer limitStart; + + /** + * s2_agent + */ + protected Integer limitEnd; + + /** + * + * @mbg.generated + */ + public AgentDOExample() { + oredCriteria = new ArrayList(); + } + + /** + * + * @mbg.generated + */ + public void setOrderByClause(String orderByClause) { + this.orderByClause = orderByClause; + } + + /** + * + * @mbg.generated + */ + public String getOrderByClause() { + return orderByClause; + } + + /** + * + * @mbg.generated + */ + public void setDistinct(boolean distinct) { + this.distinct = distinct; + } + + /** + * + * @mbg.generated + */ + public boolean isDistinct() { + return distinct; + } + + /** + * + * @mbg.generated + */ + public List getOredCriteria() { + return oredCriteria; + } + + /** + * + * @mbg.generated + */ + public void or(Criteria criteria) { + oredCriteria.add(criteria); + } + + /** + * + * @mbg.generated + */ + public Criteria or() { + Criteria criteria = createCriteriaInternal(); + oredCriteria.add(criteria); + return criteria; + } + + /** + * + * @mbg.generated + */ + public Criteria createCriteria() { + Criteria criteria = createCriteriaInternal(); + if (oredCriteria.size() == 0) { + oredCriteria.add(criteria); + } + return criteria; + } + + /** + * + * @mbg.generated + */ + protected Criteria createCriteriaInternal() { + Criteria criteria = new Criteria(); + return criteria; + } + + /** + * + * @mbg.generated + */ + public void clear() { + oredCriteria.clear(); + orderByClause = null; + distinct = false; + } + + /** + * + * @mbg.generated + */ + public void setLimitStart(Integer limitStart) { + this.limitStart=limitStart; + } + + /** + * + * @mbg.generated + */ + public Integer getLimitStart() { + return limitStart; + } + + /** + * + * @mbg.generated + */ + public void setLimitEnd(Integer limitEnd) { + this.limitEnd=limitEnd; + } + + /** + * + * @mbg.generated + */ + public Integer getLimitEnd() { + return limitEnd; + } + + /** + * s2_agent null + */ + protected abstract static class GeneratedCriteria { + protected List criteria; + + protected GeneratedCriteria() { + super(); + criteria = new ArrayList(); + } + + public boolean isValid() { + return criteria.size() > 0; + } + + public List getAllCriteria() { + return criteria; + } + + public List getCriteria() { + return criteria; + } + + protected void addCriterion(String condition) { + if (condition == null) { + throw new RuntimeException("Value for condition cannot be null"); + } + criteria.add(new Criterion(condition)); + } + + protected void addCriterion(String condition, Object value, String property) { + if (value == null) { + throw new RuntimeException("Value for " + property + " cannot be null"); + } + criteria.add(new Criterion(condition, value)); + } + + protected void addCriterion(String condition, Object value1, Object value2, String property) { + if (value1 == null || value2 == null) { + throw new RuntimeException("Between values for " + property + " cannot be null"); + } + criteria.add(new Criterion(condition, value1, value2)); + } + + public Criteria andIdIsNull() { + addCriterion("id is null"); + return (Criteria) this; + } + + public Criteria andIdIsNotNull() { + addCriterion("id is not null"); + return (Criteria) this; + } + + public Criteria andIdEqualTo(Integer value) { + addCriterion("id =", value, "id"); + return (Criteria) this; + } + + public Criteria andIdNotEqualTo(Integer value) { + addCriterion("id <>", value, "id"); + return (Criteria) this; + } + + public Criteria andIdGreaterThan(Integer value) { + addCriterion("id >", value, "id"); + return (Criteria) this; + } + + public Criteria andIdGreaterThanOrEqualTo(Integer value) { + addCriterion("id >=", value, "id"); + return (Criteria) this; + } + + public Criteria andIdLessThan(Integer value) { + addCriterion("id <", value, "id"); + return (Criteria) this; + } + + public Criteria andIdLessThanOrEqualTo(Integer value) { + addCriterion("id <=", value, "id"); + return (Criteria) this; + } + + public Criteria andIdIn(List values) { + addCriterion("id in", values, "id"); + return (Criteria) this; + } + + public Criteria andIdNotIn(List values) { + addCriterion("id not in", values, "id"); + return (Criteria) this; + } + + public Criteria andIdBetween(Integer value1, Integer value2) { + addCriterion("id between", value1, value2, "id"); + return (Criteria) this; + } + + public Criteria andIdNotBetween(Integer value1, Integer value2) { + addCriterion("id not between", value1, value2, "id"); + return (Criteria) this; + } + + public Criteria andNameIsNull() { + addCriterion("name is null"); + return (Criteria) this; + } + + public Criteria andNameIsNotNull() { + addCriterion("name is not null"); + return (Criteria) this; + } + + public Criteria andNameEqualTo(String value) { + addCriterion("name =", value, "name"); + return (Criteria) this; + } + + public Criteria andNameNotEqualTo(String value) { + addCriterion("name <>", value, "name"); + return (Criteria) this; + } + + public Criteria andNameGreaterThan(String value) { + addCriterion("name >", value, "name"); + return (Criteria) this; + } + + public Criteria andNameGreaterThanOrEqualTo(String value) { + addCriterion("name >=", value, "name"); + return (Criteria) this; + } + + public Criteria andNameLessThan(String value) { + addCriterion("name <", value, "name"); + return (Criteria) this; + } + + public Criteria andNameLessThanOrEqualTo(String value) { + addCriterion("name <=", value, "name"); + return (Criteria) this; + } + + public Criteria andNameLike(String value) { + addCriterion("name like", value, "name"); + return (Criteria) this; + } + + public Criteria andNameNotLike(String value) { + addCriterion("name not like", value, "name"); + return (Criteria) this; + } + + public Criteria andNameIn(List values) { + addCriterion("name in", values, "name"); + return (Criteria) this; + } + + public Criteria andNameNotIn(List values) { + addCriterion("name not in", values, "name"); + return (Criteria) this; + } + + public Criteria andNameBetween(String value1, String value2) { + addCriterion("name between", value1, value2, "name"); + return (Criteria) this; + } + + public Criteria andNameNotBetween(String value1, String value2) { + addCriterion("name not between", value1, value2, "name"); + return (Criteria) this; + } + + public Criteria andDescriptionIsNull() { + addCriterion("description is null"); + return (Criteria) this; + } + + public Criteria andDescriptionIsNotNull() { + addCriterion("description is not null"); + return (Criteria) this; + } + + public Criteria andDescriptionEqualTo(String value) { + addCriterion("description =", value, "description"); + return (Criteria) this; + } + + public Criteria andDescriptionNotEqualTo(String value) { + addCriterion("description <>", value, "description"); + return (Criteria) this; + } + + public Criteria andDescriptionGreaterThan(String value) { + addCriterion("description >", value, "description"); + return (Criteria) this; + } + + public Criteria andDescriptionGreaterThanOrEqualTo(String value) { + addCriterion("description >=", value, "description"); + return (Criteria) this; + } + + public Criteria andDescriptionLessThan(String value) { + addCriterion("description <", value, "description"); + return (Criteria) this; + } + + public Criteria andDescriptionLessThanOrEqualTo(String value) { + addCriterion("description <=", value, "description"); + return (Criteria) this; + } + + public Criteria andDescriptionLike(String value) { + addCriterion("description like", value, "description"); + return (Criteria) this; + } + + public Criteria andDescriptionNotLike(String value) { + addCriterion("description not like", value, "description"); + return (Criteria) this; + } + + public Criteria andDescriptionIn(List values) { + addCriterion("description in", values, "description"); + return (Criteria) this; + } + + public Criteria andDescriptionNotIn(List values) { + addCriterion("description not in", values, "description"); + return (Criteria) this; + } + + public Criteria andDescriptionBetween(String value1, String value2) { + addCriterion("description between", value1, value2, "description"); + return (Criteria) this; + } + + public Criteria andDescriptionNotBetween(String value1, String value2) { + addCriterion("description not between", value1, value2, "description"); + return (Criteria) this; + } + + public Criteria andStatusIsNull() { + addCriterion("status is null"); + return (Criteria) this; + } + + public Criteria andStatusIsNotNull() { + addCriterion("status is not null"); + return (Criteria) this; + } + + public Criteria andStatusEqualTo(Integer value) { + addCriterion("status =", value, "status"); + return (Criteria) this; + } + + public Criteria andStatusNotEqualTo(Integer value) { + addCriterion("status <>", value, "status"); + return (Criteria) this; + } + + public Criteria andStatusGreaterThan(Integer value) { + addCriterion("status >", value, "status"); + return (Criteria) this; + } + + public Criteria andStatusGreaterThanOrEqualTo(Integer value) { + addCriterion("status >=", value, "status"); + return (Criteria) this; + } + + public Criteria andStatusLessThan(Integer value) { + addCriterion("status <", value, "status"); + return (Criteria) this; + } + + public Criteria andStatusLessThanOrEqualTo(Integer value) { + addCriterion("status <=", value, "status"); + return (Criteria) this; + } + + public Criteria andStatusIn(List values) { + addCriterion("status in", values, "status"); + return (Criteria) this; + } + + public Criteria andStatusNotIn(List values) { + addCriterion("status not in", values, "status"); + return (Criteria) this; + } + + public Criteria andStatusBetween(Integer value1, Integer value2) { + addCriterion("status between", value1, value2, "status"); + return (Criteria) this; + } + + public Criteria andStatusNotBetween(Integer value1, Integer value2) { + addCriterion("status not between", value1, value2, "status"); + return (Criteria) this; + } + + public Criteria andExamplesIsNull() { + addCriterion("examples is null"); + return (Criteria) this; + } + + public Criteria andExamplesIsNotNull() { + addCriterion("examples is not null"); + return (Criteria) this; + } + + public Criteria andExamplesEqualTo(String value) { + addCriterion("examples =", value, "examples"); + return (Criteria) this; + } + + public Criteria andExamplesNotEqualTo(String value) { + addCriterion("examples <>", value, "examples"); + return (Criteria) this; + } + + public Criteria andExamplesGreaterThan(String value) { + addCriterion("examples >", value, "examples"); + return (Criteria) this; + } + + public Criteria andExamplesGreaterThanOrEqualTo(String value) { + addCriterion("examples >=", value, "examples"); + return (Criteria) this; + } + + public Criteria andExamplesLessThan(String value) { + addCriterion("examples <", value, "examples"); + return (Criteria) this; + } + + public Criteria andExamplesLessThanOrEqualTo(String value) { + addCriterion("examples <=", value, "examples"); + return (Criteria) this; + } + + public Criteria andExamplesLike(String value) { + addCriterion("examples like", value, "examples"); + return (Criteria) this; + } + + public Criteria andExamplesNotLike(String value) { + addCriterion("examples not like", value, "examples"); + return (Criteria) this; + } + + public Criteria andExamplesIn(List values) { + addCriterion("examples in", values, "examples"); + return (Criteria) this; + } + + public Criteria andExamplesNotIn(List values) { + addCriterion("examples not in", values, "examples"); + return (Criteria) this; + } + + public Criteria andExamplesBetween(String value1, String value2) { + addCriterion("examples between", value1, value2, "examples"); + return (Criteria) this; + } + + public Criteria andExamplesNotBetween(String value1, String value2) { + addCriterion("examples not between", value1, value2, "examples"); + return (Criteria) this; + } + + public Criteria andConfigIsNull() { + addCriterion("config is null"); + return (Criteria) this; + } + + public Criteria andConfigIsNotNull() { + addCriterion("config is not null"); + return (Criteria) this; + } + + public Criteria andConfigEqualTo(String value) { + addCriterion("config =", value, "config"); + return (Criteria) this; + } + + public Criteria andConfigNotEqualTo(String value) { + addCriterion("config <>", value, "config"); + return (Criteria) this; + } + + public Criteria andConfigGreaterThan(String value) { + addCriterion("config >", value, "config"); + return (Criteria) this; + } + + public Criteria andConfigGreaterThanOrEqualTo(String value) { + addCriterion("config >=", value, "config"); + return (Criteria) this; + } + + public Criteria andConfigLessThan(String value) { + addCriterion("config <", value, "config"); + return (Criteria) this; + } + + public Criteria andConfigLessThanOrEqualTo(String value) { + addCriterion("config <=", value, "config"); + return (Criteria) this; + } + + public Criteria andConfigLike(String value) { + addCriterion("config like", value, "config"); + return (Criteria) this; + } + + public Criteria andConfigNotLike(String value) { + addCriterion("config not like", value, "config"); + return (Criteria) this; + } + + public Criteria andConfigIn(List values) { + addCriterion("config in", values, "config"); + return (Criteria) this; + } + + public Criteria andConfigNotIn(List values) { + addCriterion("config not in", values, "config"); + return (Criteria) this; + } + + public Criteria andConfigBetween(String value1, String value2) { + addCriterion("config between", value1, value2, "config"); + return (Criteria) this; + } + + public Criteria andConfigNotBetween(String value1, String value2) { + addCriterion("config not between", value1, value2, "config"); + return (Criteria) this; + } + + public Criteria andCreatedByIsNull() { + addCriterion("created_by is null"); + return (Criteria) this; + } + + public Criteria andCreatedByIsNotNull() { + addCriterion("created_by is not null"); + return (Criteria) this; + } + + public Criteria andCreatedByEqualTo(String value) { + addCriterion("created_by =", value, "createdBy"); + return (Criteria) this; + } + + public Criteria andCreatedByNotEqualTo(String value) { + addCriterion("created_by <>", value, "createdBy"); + return (Criteria) this; + } + + public Criteria andCreatedByGreaterThan(String value) { + addCriterion("created_by >", value, "createdBy"); + return (Criteria) this; + } + + public Criteria andCreatedByGreaterThanOrEqualTo(String value) { + addCriterion("created_by >=", value, "createdBy"); + return (Criteria) this; + } + + public Criteria andCreatedByLessThan(String value) { + addCriterion("created_by <", value, "createdBy"); + return (Criteria) this; + } + + public Criteria andCreatedByLessThanOrEqualTo(String value) { + addCriterion("created_by <=", value, "createdBy"); + return (Criteria) this; + } + + public Criteria andCreatedByLike(String value) { + addCriterion("created_by like", value, "createdBy"); + return (Criteria) this; + } + + public Criteria andCreatedByNotLike(String value) { + addCriterion("created_by not like", value, "createdBy"); + return (Criteria) this; + } + + public Criteria andCreatedByIn(List values) { + addCriterion("created_by in", values, "createdBy"); + return (Criteria) this; + } + + public Criteria andCreatedByNotIn(List values) { + addCriterion("created_by not in", values, "createdBy"); + return (Criteria) this; + } + + public Criteria andCreatedByBetween(String value1, String value2) { + addCriterion("created_by between", value1, value2, "createdBy"); + return (Criteria) this; + } + + public Criteria andCreatedByNotBetween(String value1, String value2) { + addCriterion("created_by not between", value1, value2, "createdBy"); + return (Criteria) this; + } + + public Criteria andCreatedAtIsNull() { + addCriterion("created_at is null"); + return (Criteria) this; + } + + public Criteria andCreatedAtIsNotNull() { + addCriterion("created_at is not null"); + return (Criteria) this; + } + + public Criteria andCreatedAtEqualTo(Date value) { + addCriterion("created_at =", value, "createdAt"); + return (Criteria) this; + } + + public Criteria andCreatedAtNotEqualTo(Date value) { + addCriterion("created_at <>", value, "createdAt"); + return (Criteria) this; + } + + public Criteria andCreatedAtGreaterThan(Date value) { + addCriterion("created_at >", value, "createdAt"); + return (Criteria) this; + } + + public Criteria andCreatedAtGreaterThanOrEqualTo(Date value) { + addCriterion("created_at >=", value, "createdAt"); + return (Criteria) this; + } + + public Criteria andCreatedAtLessThan(Date value) { + addCriterion("created_at <", value, "createdAt"); + return (Criteria) this; + } + + public Criteria andCreatedAtLessThanOrEqualTo(Date value) { + addCriterion("created_at <=", value, "createdAt"); + return (Criteria) this; + } + + public Criteria andCreatedAtIn(List values) { + addCriterion("created_at in", values, "createdAt"); + return (Criteria) this; + } + + public Criteria andCreatedAtNotIn(List values) { + addCriterion("created_at not in", values, "createdAt"); + return (Criteria) this; + } + + public Criteria andCreatedAtBetween(Date value1, Date value2) { + addCriterion("created_at between", value1, value2, "createdAt"); + return (Criteria) this; + } + + public Criteria andCreatedAtNotBetween(Date value1, Date value2) { + addCriterion("created_at not between", value1, value2, "createdAt"); + return (Criteria) this; + } + + public Criteria andUpdatedByIsNull() { + addCriterion("updated_by is null"); + return (Criteria) this; + } + + public Criteria andUpdatedByIsNotNull() { + addCriterion("updated_by is not null"); + return (Criteria) this; + } + + public Criteria andUpdatedByEqualTo(String value) { + addCriterion("updated_by =", value, "updatedBy"); + return (Criteria) this; + } + + public Criteria andUpdatedByNotEqualTo(String value) { + addCriterion("updated_by <>", value, "updatedBy"); + return (Criteria) this; + } + + public Criteria andUpdatedByGreaterThan(String value) { + addCriterion("updated_by >", value, "updatedBy"); + return (Criteria) this; + } + + public Criteria andUpdatedByGreaterThanOrEqualTo(String value) { + addCriterion("updated_by >=", value, "updatedBy"); + return (Criteria) this; + } + + public Criteria andUpdatedByLessThan(String value) { + addCriterion("updated_by <", value, "updatedBy"); + return (Criteria) this; + } + + public Criteria andUpdatedByLessThanOrEqualTo(String value) { + addCriterion("updated_by <=", value, "updatedBy"); + return (Criteria) this; + } + + public Criteria andUpdatedByLike(String value) { + addCriterion("updated_by like", value, "updatedBy"); + return (Criteria) this; + } + + public Criteria andUpdatedByNotLike(String value) { + addCriterion("updated_by not like", value, "updatedBy"); + return (Criteria) this; + } + + public Criteria andUpdatedByIn(List values) { + addCriterion("updated_by in", values, "updatedBy"); + return (Criteria) this; + } + + public Criteria andUpdatedByNotIn(List values) { + addCriterion("updated_by not in", values, "updatedBy"); + return (Criteria) this; + } + + public Criteria andUpdatedByBetween(String value1, String value2) { + addCriterion("updated_by between", value1, value2, "updatedBy"); + return (Criteria) this; + } + + public Criteria andUpdatedByNotBetween(String value1, String value2) { + addCriterion("updated_by not between", value1, value2, "updatedBy"); + return (Criteria) this; + } + + public Criteria andUpdatedAtIsNull() { + addCriterion("updated_at is null"); + return (Criteria) this; + } + + public Criteria andUpdatedAtIsNotNull() { + addCriterion("updated_at is not null"); + return (Criteria) this; + } + + public Criteria andUpdatedAtEqualTo(Date value) { + addCriterion("updated_at =", value, "updatedAt"); + return (Criteria) this; + } + + public Criteria andUpdatedAtNotEqualTo(Date value) { + addCriterion("updated_at <>", value, "updatedAt"); + return (Criteria) this; + } + + public Criteria andUpdatedAtGreaterThan(Date value) { + addCriterion("updated_at >", value, "updatedAt"); + return (Criteria) this; + } + + public Criteria andUpdatedAtGreaterThanOrEqualTo(Date value) { + addCriterion("updated_at >=", value, "updatedAt"); + return (Criteria) this; + } + + public Criteria andUpdatedAtLessThan(Date value) { + addCriterion("updated_at <", value, "updatedAt"); + return (Criteria) this; + } + + public Criteria andUpdatedAtLessThanOrEqualTo(Date value) { + addCriterion("updated_at <=", value, "updatedAt"); + return (Criteria) this; + } + + public Criteria andUpdatedAtIn(List values) { + addCriterion("updated_at in", values, "updatedAt"); + return (Criteria) this; + } + + public Criteria andUpdatedAtNotIn(List values) { + addCriterion("updated_at not in", values, "updatedAt"); + return (Criteria) this; + } + + public Criteria andUpdatedAtBetween(Date value1, Date value2) { + addCriterion("updated_at between", value1, value2, "updatedAt"); + return (Criteria) this; + } + + public Criteria andUpdatedAtNotBetween(Date value1, Date value2) { + addCriterion("updated_at not between", value1, value2, "updatedAt"); + return (Criteria) this; + } + + public Criteria andEnableSearchIsNull() { + addCriterion("enable_search is null"); + return (Criteria) this; + } + + public Criteria andEnableSearchIsNotNull() { + addCriterion("enable_search is not null"); + return (Criteria) this; + } + + public Criteria andEnableSearchEqualTo(Integer value) { + addCriterion("enable_search =", value, "enableSearch"); + return (Criteria) this; + } + + public Criteria andEnableSearchNotEqualTo(Integer value) { + addCriterion("enable_search <>", value, "enableSearch"); + return (Criteria) this; + } + + public Criteria andEnableSearchGreaterThan(Integer value) { + addCriterion("enable_search >", value, "enableSearch"); + return (Criteria) this; + } + + public Criteria andEnableSearchGreaterThanOrEqualTo(Integer value) { + addCriterion("enable_search >=", value, "enableSearch"); + return (Criteria) this; + } + + public Criteria andEnableSearchLessThan(Integer value) { + addCriterion("enable_search <", value, "enableSearch"); + return (Criteria) this; + } + + public Criteria andEnableSearchLessThanOrEqualTo(Integer value) { + addCriterion("enable_search <=", value, "enableSearch"); + return (Criteria) this; + } + + public Criteria andEnableSearchIn(List values) { + addCriterion("enable_search in", values, "enableSearch"); + return (Criteria) this; + } + + public Criteria andEnableSearchNotIn(List values) { + addCriterion("enable_search not in", values, "enableSearch"); + return (Criteria) this; + } + + public Criteria andEnableSearchBetween(Integer value1, Integer value2) { + addCriterion("enable_search between", value1, value2, "enableSearch"); + return (Criteria) this; + } + + public Criteria andEnableSearchNotBetween(Integer value1, Integer value2) { + addCriterion("enable_search not between", value1, value2, "enableSearch"); + return (Criteria) this; + } + } + + /** + * s2_agent + */ + public static class Criteria extends GeneratedCriteria { + + protected Criteria() { + super(); + } + } + + /** + * s2_agent null + */ + public static class Criterion { + private String condition; + + private Object value; + + private Object secondValue; + + private boolean noValue; + + private boolean singleValue; + + private boolean betweenValue; + + private boolean listValue; + + private String typeHandler; + + public String getCondition() { + return condition; + } + + public Object getValue() { + return value; + } + + public Object getSecondValue() { + return secondValue; + } + + public boolean isNoValue() { + return noValue; + } + + public boolean isSingleValue() { + return singleValue; + } + + public boolean isBetweenValue() { + return betweenValue; + } + + public boolean isListValue() { + return listValue; + } + + public String getTypeHandler() { + return typeHandler; + } + + protected Criterion(String condition) { + super(); + this.condition = condition; + this.typeHandler = null; + this.noValue = true; + } + + protected Criterion(String condition, Object value, String typeHandler) { + super(); + this.condition = condition; + this.value = value; + this.typeHandler = typeHandler; + if (value instanceof List) { + this.listValue = true; + } else { + this.singleValue = true; + } + } + + protected Criterion(String condition, Object value) { + this(condition, value, null); + } + + protected Criterion(String condition, Object value, Object secondValue, String typeHandler) { + super(); + this.condition = condition; + this.value = value; + this.secondValue = secondValue; + this.typeHandler = typeHandler; + this.betweenValue = true; + } + + protected Criterion(String condition, Object value, Object secondValue) { + this(condition, value, secondValue, null); + } + } +} \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/mapper/AgentDOMapper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/mapper/AgentDOMapper.java new file mode 100644 index 000000000..bfa20f7ab --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/mapper/AgentDOMapper.java @@ -0,0 +1,71 @@ +package com.tencent.supersonic.chat.persistence.mapper; + +import com.tencent.supersonic.chat.persistence.dataobject.AgentDO; +import com.tencent.supersonic.chat.persistence.dataobject.AgentDOExample; +import org.apache.ibatis.annotations.Mapper; +import org.apache.ibatis.annotations.Param; + +import java.util.List; + +@Mapper +public interface AgentDOMapper { + /** + * + * @mbg.generated + */ + long countByExample(AgentDOExample example); + + /** + * + * @mbg.generated + */ + int deleteByPrimaryKey(Integer id); + + /** + * + * @mbg.generated + */ + int insert(AgentDO record); + + /** + * + * @mbg.generated + */ + int insertSelective(AgentDO record); + + /** + * + * @mbg.generated + */ + List selectByExample(AgentDOExample example); + + /** + * + * @mbg.generated + */ + AgentDO selectByPrimaryKey(Integer id); + + /** + * + * @mbg.generated + */ + int updateByExampleSelective(@Param("record") AgentDO record, @Param("example") AgentDOExample example); + + /** + * + * @mbg.generated + */ + int updateByExample(@Param("record") AgentDO record, @Param("example") AgentDOExample example); + + /** + * + * @mbg.generated + */ + int updateByPrimaryKeySelective(AgentDO record); + + /** + * + * @mbg.generated + */ + int updateByPrimaryKey(AgentDO record); +} \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/AgentRepository.java b/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/AgentRepository.java new file mode 100644 index 000000000..775209cd3 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/AgentRepository.java @@ -0,0 +1,18 @@ +package com.tencent.supersonic.chat.persistence.repository; + +import com.tencent.supersonic.chat.persistence.dataobject.AgentDO; + +import java.util.List; + +public interface AgentRepository { + + List getAgents(); + + void createAgent(AgentDO agentDO); + + void updateAgent(AgentDO agentDO); + + AgentDO getAgent(Integer id); + + void deleteAgent(Integer id); +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/impl/AgentRepositoryImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/impl/AgentRepositoryImpl.java new file mode 100644 index 000000000..43d7e0a2b --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/impl/AgentRepositoryImpl.java @@ -0,0 +1,43 @@ +package com.tencent.supersonic.chat.persistence.repository.impl; + +import com.tencent.supersonic.chat.persistence.dataobject.AgentDO; +import com.tencent.supersonic.chat.persistence.dataobject.AgentDOExample; +import com.tencent.supersonic.chat.persistence.mapper.AgentDOMapper; +import com.tencent.supersonic.chat.persistence.repository.AgentRepository; +import org.springframework.stereotype.Repository; +import java.util.List; + +@Repository +public class AgentRepositoryImpl implements AgentRepository { + + private AgentDOMapper agentDOMapper; + + public AgentRepositoryImpl(AgentDOMapper agentDOMapper) { + this.agentDOMapper = agentDOMapper; + } + + @Override + public List getAgents() { + return agentDOMapper.selectByExample(new AgentDOExample()); + } + + @Override + public void createAgent(AgentDO agentDO) { + agentDOMapper.insert(agentDO); + } + + @Override + public void updateAgent(AgentDO agentDO) { + agentDOMapper.updateByPrimaryKey(agentDO); + } + + @Override + public AgentDO getAgent(Integer id) { + return agentDOMapper.selectByPrimaryKey(id); + } + + @Override + public void deleteAgent(Integer id) { + agentDOMapper.deleteByPrimaryKey(id); + } +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/impl/ChatContextRepositoryImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/impl/ChatContextRepositoryImpl.java index 56a70f5ba..627411496 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/impl/ChatContextRepositoryImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/persistence/repository/impl/ChatContextRepositoryImpl.java @@ -7,12 +7,14 @@ import com.tencent.supersonic.chat.persistence.dataobject.ChatContextDO; import com.tencent.supersonic.chat.persistence.mapper.ChatContextMapper; import com.tencent.supersonic.chat.persistence.repository.ChatContextRepository; import com.tencent.supersonic.common.util.JsonUtil; +import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Primary; import org.springframework.stereotype.Repository; @Repository @Primary +@Slf4j public class ChatContextRepositoryImpl implements ChatContextRepository { @Autowired(required = false) @@ -50,8 +52,8 @@ public class ChatContextRepositoryImpl implements ChatContextRepository { chatContext.setUser(contextDO.getUser()); chatContext.setQueryText(contextDO.getQueryText()); if (contextDO.getSemanticParse() != null && !contextDO.getSemanticParse().isEmpty()) { - SemanticParseInfo semanticParseInfo = JsonUtil.toObject(contextDO.getSemanticParse(), - SemanticParseInfo.class); + log.info("--->: {}",contextDO.getSemanticParse()); + SemanticParseInfo semanticParseInfo = JsonUtil.toObject(contextDO.getSemanticParse(), SemanticParseInfo.class); chatContext.setParseInfo(semanticParseInfo); } return chatContext; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/plugin/PluginManager.java b/chat/core/src/main/java/com/tencent/supersonic/chat/plugin/PluginManager.java index 6919b4b5d..c1f1611ad 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/plugin/PluginManager.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/plugin/PluginManager.java @@ -3,11 +3,11 @@ package com.tencent.supersonic.chat.plugin; import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Lists; import com.google.common.collect.Sets; -import com.tencent.supersonic.chat.api.pojo.QueryContext; -import com.tencent.supersonic.chat.api.pojo.SchemaElement; -import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; -import com.tencent.supersonic.chat.api.pojo.SchemaElementType; -import com.tencent.supersonic.chat.api.pojo.SchemaMapInfo; +import com.tencent.supersonic.chat.agent.tool.DslTool; +import com.tencent.supersonic.chat.api.pojo.*; +import com.tencent.supersonic.chat.agent.Agent; +import com.tencent.supersonic.chat.agent.tool.AgentToolType; +import com.tencent.supersonic.chat.agent.tool.PluginTool; import com.tencent.supersonic.chat.parser.ParseMode; import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig; import com.tencent.supersonic.chat.parser.embedding.EmbeddingResp; @@ -16,30 +16,22 @@ import com.tencent.supersonic.chat.plugin.event.PluginAddEvent; import com.tencent.supersonic.chat.plugin.event.PluginUpdateEvent; import com.tencent.supersonic.chat.query.plugin.ParamOption; import com.tencent.supersonic.chat.query.plugin.WebBase; +import com.tencent.supersonic.chat.service.AgentService; import com.tencent.supersonic.chat.service.PluginService; import com.tencent.supersonic.common.util.ContextUtils; import java.net.URI; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.Set; +import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; +import java.util.stream.Stream; + import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.tuple.Pair; import org.apache.logging.log4j.util.Strings; import org.springframework.context.event.EventListener; import org.springframework.core.ParameterizedTypeReference; -import org.springframework.http.HttpEntity; -import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpMethod; -import org.springframework.http.MediaType; -import org.springframework.http.ResponseEntity; +import org.springframework.http.*; import org.springframework.stereotype.Component; import org.springframework.web.client.RestTemplate; import org.springframework.web.util.UriComponentsBuilder; @@ -59,12 +51,40 @@ public class PluginManager { this.restTemplate = restTemplate; } - public static List getPlugins() { + public static List getPluginAgentCanSupport(Integer agentId) { PluginService pluginService = ContextUtils.getBean(PluginService.class); - List pluginList = pluginService.getPluginList().stream().filter(plugin -> - CollectionUtils.isNotEmpty(plugin.getModelList())).collect(Collectors.toList()); - pluginList.addAll(internalPluginMap.values()); - return new ArrayList<>(pluginList); + List plugins = pluginService.getPluginList(); + if (agentId == null) { + return plugins; + } + Agent agent = ContextUtils.getBean(AgentService.class).getAgent(agentId); + if (agent == null) { + return plugins; + } + List pluginIds = getPluginTools(agentId).stream().map(PluginTool::getPlugins) + .flatMap(Collection::stream).collect(Collectors.toList()); + if (CollectionUtils.isEmpty(pluginIds)) { + return Lists.newArrayList(); + } + plugins = plugins.stream().filter(plugin -> pluginIds.contains(plugin.getId())) + .collect(Collectors.toList()); + log.info("plugins witch can be supported by cur agent :{} {}", agent.getName(), + plugins.stream().map(Plugin::getName).collect(Collectors.toList())); + return plugins; + } + + private static List getPluginTools(Integer agentId) { + AgentService agentService = ContextUtils.getBean(AgentService.class); + Agent agent = agentService.getAgent(agentId); + if (agent == null) { + return Lists.newArrayList(); + } + List tools = agent.getTools(AgentToolType.PLUGIN); + if (CollectionUtils.isEmpty(tools)) { + return Lists.newArrayList(); + } + return tools.stream().map(tool -> JSONObject.parseObject(tool, PluginTool.class)) + .collect(Collectors.toList()); } @EventListener @@ -201,17 +221,17 @@ public class PluginManager { return String.valueOf(Integer.parseInt(id) / 1000); } - public static Pair> resolve(Plugin plugin, QueryContext queryContext) { + public static Pair> resolve(Plugin plugin, QueryContext queryContext) { SchemaMapInfo schemaMapInfo = queryContext.getMapInfo(); Set pluginMatchedModel = getPluginMatchedModel(plugin, queryContext); if (CollectionUtils.isEmpty(pluginMatchedModel) && !plugin.isContainsAllModel()) { - return Pair.of(false, Lists.newArrayList()); + return Pair.of(false, Sets.newHashSet()); } List paramOptions = getSemanticOption(plugin); if (CollectionUtils.isEmpty(paramOptions)) { - return Pair.of(true, new ArrayList<>(pluginMatchedModel)); + return Pair.of(true, Sets.newHashSet()); } - List matchedModel = Lists.newArrayList(); + Set matchedModel = Sets.newHashSet(); Map> paramOptionMap = paramOptions.stream(). collect(Collectors.groupingBy(ParamOption::getModelId)); for (Long modelId : paramOptionMap.keySet()) { @@ -237,7 +257,7 @@ public class PluginManager { } } if (CollectionUtils.isEmpty(matchedModel)) { - return Pair.of(false, Lists.newArrayList()); + return Pair.of(false, Sets.newHashSet()); } return Pair.of(true, matchedModel); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/ContentInterpret/ContentInterpretQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/ContentInterpret/ContentInterpretQuery.java deleted file mode 100644 index b2d43db6c..000000000 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/ContentInterpret/ContentInterpretQuery.java +++ /dev/null @@ -1,149 +0,0 @@ -package com.tencent.supersonic.chat.query.ContentInterpret; - -import com.alibaba.fastjson.JSONObject; -import com.google.common.collect.Lists; -import com.tencent.supersonic.auth.api.authentication.pojo.User; -import com.tencent.supersonic.chat.api.component.SemanticLayer; -import com.tencent.supersonic.chat.api.pojo.ModelSchema; -import com.tencent.supersonic.chat.api.pojo.SchemaElement; -import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; -import com.tencent.supersonic.chat.api.pojo.response.QueryResult; -import com.tencent.supersonic.chat.api.pojo.response.QueryState; -import com.tencent.supersonic.chat.plugin.PluginManager; -import com.tencent.supersonic.chat.plugin.PluginParseResult; -import com.tencent.supersonic.chat.query.QueryManager; -import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery; -import com.tencent.supersonic.chat.utils.ComponentFactory; -import com.tencent.supersonic.common.pojo.Aggregator; -import com.tencent.supersonic.common.pojo.Constants; -import com.tencent.supersonic.common.pojo.DateConf; -import com.tencent.supersonic.common.pojo.QueryColumn; -import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; -import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; -import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; -import com.tencent.supersonic.semantic.api.query.pojo.Filter; -import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; -import lombok.extern.slf4j.Slf4j; -import org.apache.calcite.sql.parser.SqlParseException; -import org.springframework.beans.BeanUtils; -import org.springframework.http.ResponseEntity; -import org.springframework.stereotype.Component; -import org.springframework.util.CollectionUtils; - -@Slf4j -@Component -public class ContentInterpretQuery extends PluginSemanticQuery { - - @Override - public String getQueryMode() { - return "CONTENT_INTERPRET"; - } - - public ContentInterpretQuery() { - QueryManager.register(this); - } - - @Override - public QueryResult execute(User user) throws SqlParseException { - QueryResultWithSchemaResp queryResultWithSchemaResp = queryMetric(user); - String text = generateDataText(queryResultWithSchemaResp); - Map properties = parseInfo.getProperties(); - PluginParseResult pluginParseResult = JsonUtil.toObject(JsonUtil.toString(properties.get(Constants.CONTEXT)) - , PluginParseResult.class); - String answer = fetchInterpret(pluginParseResult.getRequest().getQueryText(), text); - QueryResult queryResult = new QueryResult(); - List queryColumns = Lists.newArrayList(new QueryColumn("结果", "string", "answer")); - Map result = new HashMap<>(); - result.put("answer", answer); - List> resultList = Lists.newArrayList(); - resultList.add(result); - queryResultWithSchemaResp.setResultList(resultList); - queryResultWithSchemaResp.setColumns(queryColumns); - queryResult.setResponse(queryResultWithSchemaResp); - queryResult.setQueryMode(getQueryMode()); - queryResult.setQueryState(QueryState.SUCCESS); - return queryResult; - } - - - private QueryResultWithSchemaResp queryMetric(User user) { - SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer(); - QueryStructReq queryStructReq = new QueryStructReq(); - queryStructReq.setModelId(parseInfo.getModelId()); - queryStructReq.setGroups(Lists.newArrayList(TimeDimensionEnum.DAY.getName())); - ModelSchema modelSchema = semanticLayer.getModelSchema(parseInfo.getModelId(), true); - queryStructReq.setAggregators(buildAggregator(modelSchema)); - List filterList = Lists.newArrayList(); - for (QueryFilter queryFilter : parseInfo.getDimensionFilters()) { - Filter filter = new Filter(); - BeanUtils.copyProperties(queryFilter, filter); - filterList.add(filter); - } - queryStructReq.setDimensionFilters(filterList); - DateConf dateConf = new DateConf(); - dateConf.setDateMode(DateConf.DateMode.RECENT); - dateConf.setUnit(7); - queryStructReq.setDateInfo(dateConf); - return semanticLayer.queryByStruct(queryStructReq, user); - } - - private List buildAggregator(ModelSchema modelSchema) { - List aggregators = Lists.newArrayList(); - Set metrics = modelSchema.getMetrics(); - if (CollectionUtils.isEmpty(metrics)) { - return aggregators; - } - for (SchemaElement schemaElement : metrics) { - Aggregator aggregator = new Aggregator(); - aggregator.setColumn(schemaElement.getBizName()); - aggregator.setFunc(AggOperatorEnum.SUM); - aggregator.setNameCh(schemaElement.getName()); - aggregators.add(aggregator); - } - return aggregators; - } - - - public String generateDataText(QueryResultWithSchemaResp queryResultWithSchemaResp) { - Map map = queryResultWithSchemaResp.getColumns().stream() - .collect(Collectors.toMap(QueryColumn::getNameEn, QueryColumn::getName)); - StringBuilder stringBuilder = new StringBuilder(); - for (Map valueMap : queryResultWithSchemaResp.getResultList()) { - for (String key : valueMap.keySet()) { - String name = ""; - if (TimeDimensionEnum.getNameList().contains(key)) { - name = "日期"; - } else { - name = map.get(key); - } - String value = String.valueOf(valueMap.get(key)); - stringBuilder.append(name).append(":").append(value).append(" "); - } - } - return stringBuilder.toString(); - } - - - public String fetchInterpret(String queryText, String dataText) { - PluginManager pluginManager = ContextUtils.getBean(PluginManager.class); - LLmAnswerReq lLmAnswerReq = new LLmAnswerReq(); - lLmAnswerReq.setQueryText(queryText); - lLmAnswerReq.setPluginOutput(dataText); - ResponseEntity responseEntity = pluginManager.doRequest("answer_with_plugin_call", - JSONObject.toJSONString(lLmAnswerReq)); - LLmAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLmAnswerResp.class); - if (lLmAnswerResp != null) { - return lLmAnswerResp.getAssistant_message(); - } - return null; - } - - -} \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/HeuristicQuerySelector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/HeuristicQuerySelector.java index 31f892f37..7f14b46d0 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/HeuristicQuerySelector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/HeuristicQuerySelector.java @@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.query; import com.tencent.supersonic.chat.api.component.SemanticQuery; import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.query.rule.RuleSemanticQuery; import com.tencent.supersonic.chat.query.rule.metric.MetricEntityQuery; import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery; @@ -18,7 +19,7 @@ public class HeuristicQuerySelector implements QuerySelector { private static final double CANDIDATE_THRESHOLD = 0.2; @Override - public List select(List candidateQueries) { + public List select(List candidateQueries, QueryReq queryReq) { List selectedQueries = new ArrayList<>(); if (CollectionUtils.isNotEmpty(candidateQueries) && candidateQueries.size() == 1) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/QuerySelector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/QuerySelector.java index fdfe5a4d0..51ecf9f50 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/QuerySelector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/QuerySelector.java @@ -1,6 +1,8 @@ package com.tencent.supersonic.chat.query; import com.tencent.supersonic.chat.api.component.SemanticQuery; +import com.tencent.supersonic.chat.api.pojo.request.QueryReq; + import java.util.List; /** @@ -8,5 +10,5 @@ import java.util.List; **/ public interface QuerySelector { - List select(List candidateQueries); + List select(List candidateQueries, QueryReq queryReq); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/DSLBuilder.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/DSLBuilder.java deleted file mode 100644 index fef2077aa..000000000 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/DSLBuilder.java +++ /dev/null @@ -1,78 +0,0 @@ -package com.tencent.supersonic.chat.query.dsl; - -import com.tencent.supersonic.chat.api.pojo.SchemaElement; -import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.chat.api.pojo.SemanticSchema; -import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; -import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; -import com.tencent.supersonic.common.pojo.Constants; -import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.common.util.StringUtil; -import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils; -import com.tencent.supersonic.knowledge.service.SchemaService; -import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.stream.Collectors; -import lombok.extern.slf4j.Slf4j; -import net.sf.jsqlparser.expression.Expression; -import net.sf.jsqlparser.parser.CCJSqlParserUtil; -import org.apache.commons.collections.CollectionUtils; -import org.apache.commons.lang3.StringUtils; - -@Slf4j -public class DSLBuilder { - - public static final String DATA_Field = "数据日期"; - public static final String TABLE_PREFIX = "t_"; - - public String build(SemanticParseInfo parseInfo, QueryFilters queryFilters, LLMResp llmResp, Long modelId) - throws Exception { - - String sqlOutput = llmResp.getSqlOutput(); - SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); - List dbAllFields = new ArrayList<>(); - dbAllFields.addAll(semanticSchema.getMetrics()); - dbAllFields.addAll(semanticSchema.getDimensions()); - - Map fieldToBizName = getMapInfo(modelId, dbAllFields); - fieldToBizName.put(DATA_Field, TimeDimensionEnum.DAY.getName()); - - sqlOutput = CCJSqlParserUtils.replaceFields(sqlOutput, fieldToBizName); - - sqlOutput = CCJSqlParserUtils.replaceTable(sqlOutput, TABLE_PREFIX + modelId); - - String queryFilter = getQueryFilter(queryFilters); - if (StringUtils.isNotEmpty(queryFilter)) { - log.info("add queryFilter to sql :{}", queryFilter); - Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter); - CCJSqlParserUtils.addWhere(sqlOutput, expression); - } - - log.info("build sqlOutput:{}", sqlOutput); - return sqlOutput; - } - - protected Map getMapInfo(Long modelId, List metrics) { - return metrics.stream().filter(entry -> entry.getModel().equals(modelId)) - .collect(Collectors.toMap(SchemaElement::getName, a -> a.getBizName(), (k1, k2) -> k1)); - } - - private String getQueryFilter(QueryFilters queryFilters) { - if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) { - return ""; - } - List filters = queryFilters.getFilters(); - - return filters.stream() - .map(filter -> { - String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName()); - String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue()); - String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString()); - return bizNameWrap + operatorWrap + valueWrap; - }) - .collect(Collectors.joining(Constants.AND_UPPER)); - } -} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/DSLQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/DSLQuery.java index fc09b9637..2971ddc97 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/DSLQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/DSLQuery.java @@ -2,14 +2,14 @@ package com.tencent.supersonic.chat.query.dsl; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.component.SemanticLayer; -import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryState; -import com.tencent.supersonic.chat.parser.llm.DSLParseResult; +import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult; import com.tencent.supersonic.chat.query.QueryManager; +import com.tencent.supersonic.chat.api.pojo.CorrectionInfo; +import com.tencent.supersonic.chat.api.component.DSLOptimizer; import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery; import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.chat.utils.ComponentFactory; @@ -32,7 +32,6 @@ import org.springframework.stereotype.Component; public class DSLQuery extends PluginSemanticQuery { public static final String QUERY_MODE = "DSL"; - private DSLBuilder dslBuilder = new DSLBuilder(); protected SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer(); public DSLQuery() { @@ -51,12 +50,26 @@ public class DSLQuery extends PluginSemanticQuery { LLMResp llmResp = dslParseResult.getLlmResp(); QueryReq queryReq = dslParseResult.getRequest(); - Long modelId = parseInfo.getModelId(); - String querySql = convertToSql(queryReq.getQueryFilters(), llmResp, parseInfo, modelId); + CorrectionInfo correctionInfo = CorrectionInfo.builder() + .queryFilters(queryReq.getQueryFilters()) + .sql(llmResp.getSqlOutput()) + .parseInfo(parseInfo) + .build(); + + List DSLCorrections = ComponentFactory.getSqlCorrections(); + + DSLCorrections.forEach(DSLCorrection -> { + try { + DSLCorrection.rewriter(correctionInfo); + log.info("sqlCorrection:{} sql:{}", DSLCorrection.getClass().getSimpleName(), correctionInfo.getSql()); + } catch (Exception e) { + log.error("sqlCorrection:{} execute error,correctionInfo:{}", DSLCorrection, correctionInfo, e); + } + }); + String querySql = correctionInfo.getSql(); long startTime = System.currentTimeMillis(); - - QueryDslReq queryDslReq = QueryReqBuilder.buildDslReq(querySql, modelId); + QueryDslReq queryDslReq = QueryReqBuilder.buildDslReq(querySql, parseInfo.getModelId()); QueryResultWithSchemaResp queryResp = semanticLayer.queryByDsl(queryDslReq, user); log.info("queryByDsl cost:{},querySql:{}", System.currentTimeMillis() - startTime, querySql); @@ -80,17 +93,4 @@ public class DSLQuery extends PluginSemanticQuery { parseInfo.setProperties(null); return queryResult; } - - - protected String convertToSql(QueryFilters queryFilters, LLMResp llmResp, SemanticParseInfo parseInfo, - Long modelId) { - try { - return dslBuilder.build(parseInfo, queryFilters, llmResp, modelId); - } catch (Exception e) { - log.error("convertToSql error", e); - } - return null; - } - - } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/BaseDSLOptimizer.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/BaseDSLOptimizer.java new file mode 100644 index 000000000..72810d7c6 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/BaseDSLOptimizer.java @@ -0,0 +1,33 @@ +package com.tencent.supersonic.chat.query.dsl.optimizer; + +import com.tencent.supersonic.chat.api.component.DSLOptimizer; +import com.tencent.supersonic.chat.api.pojo.SchemaElement; +import com.tencent.supersonic.chat.api.pojo.SemanticSchema; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.knowledge.service.SchemaService; +import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public abstract class BaseDSLOptimizer implements DSLOptimizer { + public static final String DATE_FIELD = "数据日期"; + protected Map getFieldToBizName(Long modelId) { + + SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); + + List dbAllFields = new ArrayList<>(); + dbAllFields.addAll(semanticSchema.getMetrics()); + dbAllFields.addAll(semanticSchema.getDimensions()); + + Map result = dbAllFields.stream() + .filter(entry -> entry.getModel().equals(modelId)) + .collect(Collectors.toMap(SchemaElement::getName, a -> a.getBizName(), (k1, k2) -> k1)); + result.put(DATE_FIELD, TimeDimensionEnum.DAY.getName()); + return result; + } + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/DateFieldCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/DateFieldCorrector.java new file mode 100644 index 000000000..14ce36d23 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/DateFieldCorrector.java @@ -0,0 +1,26 @@ +package com.tencent.supersonic.chat.query.dsl.optimizer; + +import com.tencent.supersonic.chat.api.pojo.CorrectionInfo; +import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper; +import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils; +import java.util.List; +import lombok.extern.slf4j.Slf4j; +import org.springframework.util.CollectionUtils; + +@Slf4j +public class DateFieldCorrector extends BaseDSLOptimizer { + + @Override + public CorrectionInfo rewriter(CorrectionInfo correctionInfo) { + + String sql = correctionInfo.getSql(); + List whereFields = CCJSqlParserUtils.getWhereFields(sql); + if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(BaseDSLOptimizer.DATE_FIELD)) { + String currentDate = DSLDateHelper.getCurrentDate(correctionInfo.getParseInfo().getModelId()); + sql = CCJSqlParserUtils.addWhere(sql, BaseDSLOptimizer.DATE_FIELD, currentDate); + } + correctionInfo.setSql(sql); + return correctionInfo; + } + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/FieldCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/FieldCorrector.java new file mode 100644 index 000000000..1ff0f1ed2 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/FieldCorrector.java @@ -0,0 +1,17 @@ +package com.tencent.supersonic.chat.query.dsl.optimizer; + +import com.tencent.supersonic.chat.api.pojo.CorrectionInfo; +import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class FieldCorrector extends BaseDSLOptimizer { + + @Override + public CorrectionInfo rewriter(CorrectionInfo correctionInfo) { + String replaceFields = CCJSqlParserUtils.replaceFields(correctionInfo.getSql(), + getFieldToBizName(correctionInfo.getParseInfo().getModelId())); + correctionInfo.setSql(replaceFields); + return correctionInfo; + } +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/FunctionCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/FunctionCorrector.java new file mode 100644 index 000000000..f5cf6af4c --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/FunctionCorrector.java @@ -0,0 +1,16 @@ +package com.tencent.supersonic.chat.query.dsl.optimizer; + +import com.tencent.supersonic.chat.api.pojo.CorrectionInfo; +import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class FunctionCorrector extends BaseDSLOptimizer { + + @Override + public CorrectionInfo rewriter(CorrectionInfo correctionInfo) { + String replaceFunction = CCJSqlParserUtils.replaceFunction(correctionInfo.getSql()); + correctionInfo.setSql(replaceFunction); + return correctionInfo; + } +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/QueryFilterAppend.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/QueryFilterAppend.java new file mode 100644 index 000000000..6e6806b47 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/QueryFilterAppend.java @@ -0,0 +1,48 @@ +package com.tencent.supersonic.chat.query.dsl.optimizer; + +import com.tencent.supersonic.chat.api.pojo.CorrectionInfo; +import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; +import com.tencent.supersonic.common.pojo.Constants; +import com.tencent.supersonic.common.util.StringUtil; +import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils; +import java.util.Objects; +import java.util.stream.Collectors; +import lombok.extern.slf4j.Slf4j; +import net.sf.jsqlparser.JSQLParserException; +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import org.apache.commons.collections.CollectionUtils; +import org.apache.commons.lang3.StringUtils; + +@Slf4j +public class QueryFilterAppend extends BaseDSLOptimizer { + + @Override + public CorrectionInfo rewriter(CorrectionInfo correctionInfo) throws JSQLParserException { + String queryFilter = getQueryFilter(correctionInfo.getQueryFilters()); + String sql = correctionInfo.getSql(); + + if (StringUtils.isNotEmpty(queryFilter)) { + log.info("add queryFilter to sql :{}", queryFilter); + Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter); + sql = CCJSqlParserUtils.addWhere(sql, expression); + } + correctionInfo.setSql(sql); + return correctionInfo; + } + + private String getQueryFilter(QueryFilters queryFilters) { + if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) { + return null; + } + return queryFilters.getFilters().stream() + .map(filter -> { + String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName()); + String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue()); + String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString()); + return bizNameWrap + operatorWrap + valueWrap; + }) + .collect(Collectors.joining(Constants.AND_UPPER)); + } + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/SelectFieldAppendCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/SelectFieldAppendCorrector.java new file mode 100644 index 000000000..33e1e96c9 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/SelectFieldAppendCorrector.java @@ -0,0 +1,35 @@ +package com.tencent.supersonic.chat.query.dsl.optimizer; + +import com.tencent.supersonic.chat.api.pojo.CorrectionInfo; +import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils; +import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Set; +import lombok.extern.slf4j.Slf4j; +import org.springframework.util.CollectionUtils; + +@Slf4j +public class SelectFieldAppendCorrector extends BaseDSLOptimizer { + + @Override + public CorrectionInfo rewriter(CorrectionInfo correctionInfo) { + String sql = correctionInfo.getSql(); + if (CCJSqlParserUtils.hasAggregateFunction(sql)) { + return correctionInfo; + } + Set selectFields = new HashSet<>(CCJSqlParserUtils.getSelectFields(sql)); + Set whereFields = new HashSet<>(CCJSqlParserUtils.getWhereFields(sql)); + if (CollectionUtils.isEmpty(selectFields) || CollectionUtils.isEmpty(whereFields)) { + return correctionInfo; + } + + whereFields.removeAll(selectFields); + whereFields.remove(TimeDimensionEnum.DAY.getName()); + whereFields.remove(TimeDimensionEnum.WEEK.getName()); + whereFields.remove(TimeDimensionEnum.MONTH.getName()); + String replaceFields = CCJSqlParserUtils.addFieldsToSelect(sql, new ArrayList<>(whereFields)); + correctionInfo.setSql(replaceFields); + return correctionInfo; + } +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/TableNameCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/TableNameCorrector.java new file mode 100644 index 000000000..1604e4b9c --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/dsl/optimizer/TableNameCorrector.java @@ -0,0 +1,21 @@ +package com.tencent.supersonic.chat.query.dsl.optimizer; + +import com.tencent.supersonic.chat.api.pojo.CorrectionInfo; +import com.tencent.supersonic.common.util.jsqlparser.CCJSqlParserUtils; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class TableNameCorrector extends BaseDSLOptimizer { + + public static final String TABLE_PREFIX = "t_"; + + @Override + public CorrectionInfo rewriter(CorrectionInfo correctionInfo) { + Long modelId = correctionInfo.getParseInfo().getModelId(); + String sqlOutput = correctionInfo.getSql(); + String replaceTable = CCJSqlParserUtils.replaceTable(sqlOutput, TABLE_PREFIX + modelId); + correctionInfo.setSql(replaceTable); + return correctionInfo; + } + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/ContentInterpret/LLmAnswerReq.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/metricInterpret/LLmAnswerReq.java similarity index 67% rename from chat/core/src/main/java/com/tencent/supersonic/chat/query/ContentInterpret/LLmAnswerReq.java rename to chat/core/src/main/java/com/tencent/supersonic/chat/query/metricInterpret/LLmAnswerReq.java index d678c80f3..bab24c6ca 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/ContentInterpret/LLmAnswerReq.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/metricInterpret/LLmAnswerReq.java @@ -1,4 +1,4 @@ -package com.tencent.supersonic.chat.query.ContentInterpret; +package com.tencent.supersonic.chat.query.metricInterpret; import lombok.Data; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/ContentInterpret/LLmAnswerResp.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/metricInterpret/LLmAnswerResp.java similarity index 62% rename from chat/core/src/main/java/com/tencent/supersonic/chat/query/ContentInterpret/LLmAnswerResp.java rename to chat/core/src/main/java/com/tencent/supersonic/chat/query/metricInterpret/LLmAnswerResp.java index 728b696fe..32f4991ea 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/ContentInterpret/LLmAnswerResp.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/metricInterpret/LLmAnswerResp.java @@ -1,4 +1,4 @@ -package com.tencent.supersonic.chat.query.ContentInterpret; +package com.tencent.supersonic.chat.query.metricInterpret; import lombok.Data; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/metricInterpret/MetricInterpretQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/metricInterpret/MetricInterpretQuery.java new file mode 100644 index 000000000..2c48f1fe1 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/metricInterpret/MetricInterpretQuery.java @@ -0,0 +1,143 @@ +package com.tencent.supersonic.chat.query.metricInterpret; + +import com.alibaba.fastjson.JSONObject; +import com.google.common.collect.Lists; +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.chat.api.component.SemanticLayer; +import com.tencent.supersonic.chat.api.pojo.SchemaElement; +import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; +import com.tencent.supersonic.chat.api.pojo.SchemaElementType; +import com.tencent.supersonic.chat.api.pojo.response.QueryResult; +import com.tencent.supersonic.chat.api.pojo.response.QueryState; +import com.tencent.supersonic.chat.plugin.PluginManager; +import com.tencent.supersonic.chat.query.QueryManager; +import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery; +import com.tencent.supersonic.chat.utils.ComponentFactory; +import com.tencent.supersonic.chat.utils.QueryReqBuilder; +import com.tencent.supersonic.common.pojo.Aggregator; +import com.tencent.supersonic.common.pojo.QueryColumn; +import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; +import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; +import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; +import lombok.extern.slf4j.Slf4j; +import org.apache.calcite.sql.parser.SqlParseException; +import org.apache.commons.lang3.StringUtils; +import org.springframework.http.ResponseEntity; +import org.springframework.stereotype.Component; +import org.springframework.util.CollectionUtils; +import java.util.*; +import java.util.stream.Collectors; + +@Slf4j +@Component +public class MetricInterpretQuery extends PluginSemanticQuery { + + + public final static String QUERY_MODE = "METRIC_INTERPRET"; + + @Override + public String getQueryMode() { + return QUERY_MODE; + } + + public MetricInterpretQuery() { + QueryManager.register(this); + } + + @Override + public QueryResult execute(User user) throws SqlParseException { + QueryStructReq queryStructReq = QueryReqBuilder.buildStructReq(parseInfo); + fillAggregator(queryStructReq, parseInfo.getMetrics()); + queryStructReq.setNativeQuery(true); + SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer(); + QueryResultWithSchemaResp queryResultWithSchemaResp = semanticLayer.queryByStruct(queryStructReq, user); + String text = generateTableText(queryResultWithSchemaResp); + Map properties = parseInfo.getProperties(); + Map replacedMap = new HashMap<>(); + String textReplaced = replaceText((String) properties.get("queryText"), parseInfo.getElementMatches(), replacedMap); + String answer = replaceAnswer(fetchInterpret(textReplaced, text), replacedMap); + QueryResult queryResult = new QueryResult(); + List queryColumns = Lists.newArrayList(new QueryColumn("结果","string","answer")); + Map result = new HashMap<>(); + result.put("answer", answer); + List> resultList = Lists.newArrayList(); + resultList.add(result); + queryResult.setQueryResults(resultList); + queryResult.setQueryColumns(queryColumns); + queryResult.setQueryMode(getQueryMode()); + queryResult.setQueryState(QueryState.SUCCESS); + return queryResult; + } + + private String replaceText(String text, List schemaElementMatches, Map replacedMap) { + if (CollectionUtils.isEmpty(schemaElementMatches)) { + return text; + } + List valueSchemaElementMatches = schemaElementMatches.stream() + .filter(schemaElementMatch -> + SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()) + || SchemaElementType.ID.equals(schemaElementMatch.getElement().getType())) + .collect(Collectors.toList()); + for (SchemaElementMatch schemaElementMatch : valueSchemaElementMatches) { + String detectWord = schemaElementMatch.getDetectWord(); + if (StringUtils.isBlank(detectWord)) { + continue; + } + text = text.replace(detectWord, "xxx"); + replacedMap.put("xxx", detectWord); + } + return text; + } + + private void fillAggregator(QueryStructReq queryStructReq, Set schemaElements) { + queryStructReq.getAggregators().clear(); + for (SchemaElement schemaElement : schemaElements) { + Aggregator aggregator = new Aggregator(); + aggregator.setColumn(schemaElement.getBizName()); + aggregator.setFunc(AggOperatorEnum.SUM); + aggregator.setNameCh(schemaElement.getName()); + queryStructReq.getAggregators().add(aggregator); + } + } + + + private String replaceAnswer(String text, Map replacedMap) { + for (String key : replacedMap.keySet()) { + text = text.replaceAll(key, replacedMap.get(key)); + } + return text; + } + + public static String generateTableText(QueryResultWithSchemaResp result) { + StringBuilder tableBuilder = new StringBuilder(); + for (QueryColumn column : result.getColumns()) { + tableBuilder.append(column.getName()).append("\t"); + } + tableBuilder.append("\n"); + for (Map row : result.getResultList()) { + for (QueryColumn column : result.getColumns()) { + tableBuilder.append(row.get(column.getNameEn())).append("\t"); + } + tableBuilder.append("\n"); + } + return tableBuilder.toString(); + } + + + public String fetchInterpret(String queryText, String dataText) { + PluginManager pluginManager = ContextUtils.getBean(PluginManager.class); + LLmAnswerReq lLmAnswerReq = new LLmAnswerReq(); + lLmAnswerReq.setQueryText(queryText); + lLmAnswerReq.setPluginOutput(dataText); + ResponseEntity responseEntity = pluginManager.doRequest("answer_with_plugin_call", + JSONObject.toJSONString(lLmAnswerReq)); + LLmAnswerResp lLmAnswerResp = JSONObject.parseObject(responseEntity.getBody(), LLmAnswerResp.class); + if (lLmAnswerResp != null) { + return lLmAnswerResp.getAssistant_message(); + } + return null; + } + + +} \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/webpage/WebPageQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/webpage/WebPageQuery.java index 3c92e6d01..82f51f705 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/webpage/WebPageQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/webpage/WebPageQuery.java @@ -2,15 +2,14 @@ package com.tencent.supersonic.chat.query.plugin.webpage; import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; -import com.tencent.supersonic.chat.api.pojo.ModelSchema; -import com.tencent.supersonic.chat.api.pojo.SchemaElementMatch; -import com.tencent.supersonic.chat.api.pojo.SchemaElementType; +import com.tencent.supersonic.chat.api.pojo.*; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryState; +import com.tencent.supersonic.chat.api.pojo.response.ChatConfigRichResp; import com.tencent.supersonic.chat.plugin.Plugin; import com.tencent.supersonic.chat.plugin.PluginParseResult; import com.tencent.supersonic.chat.query.QueryManager; @@ -18,18 +17,18 @@ import com.tencent.supersonic.chat.query.plugin.ParamOption; import com.tencent.supersonic.chat.query.plugin.PluginSemanticQuery; import com.tencent.supersonic.chat.query.plugin.WebBase; import com.tencent.supersonic.chat.query.plugin.WebBaseResult; +import com.tencent.supersonic.chat.service.ConfigService; import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.JsonUtil; -import java.util.Comparator; -import java.util.HashMap; -import java.util.List; -import java.util.Map; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Component; import org.springframework.util.CollectionUtils; +import java.util.*; +import java.util.stream.Collectors; + @Slf4j @Component public class WebPageQuery extends PluginSemanticQuery { @@ -107,17 +106,15 @@ public class WebPageQuery extends PluginSemanticQuery { .filter(schemaElementMatch -> SchemaElementType.VALUE.equals(schemaElementMatch.getElement().getType()) || SchemaElementType.ID.equals(schemaElementMatch.getElement().getType())) - .sorted(Comparator.comparingDouble(SchemaElementMatch::getSimilarity)) + .filter(schemaElementMatch -> schemaElementMatch.getSimilarity() == 1.0) .forEach(schemaElementMatch -> { Object queryFilterValue = filterValueMap.get(schemaElementMatch.getElement().getId()); if (queryFilterValue != null) { if (String.valueOf(queryFilterValue).equals(String.valueOf(schemaElementMatch.getWord()))) { - elementValueMap.put(String.valueOf(schemaElementMatch.getElement().getId()), - schemaElementMatch.getWord()); + elementValueMap.put(String.valueOf(schemaElementMatch.getElement().getId()), schemaElementMatch.getWord()); } } else { - elementValueMap.put(String.valueOf(schemaElementMatch.getElement().getId()), - schemaElementMatch.getWord()); + elementValueMap.computeIfAbsent(String.valueOf(schemaElementMatch.getElement().getId()), k -> schemaElementMatch.getWord()); } }); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java index ee445776a..873354b5f 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java @@ -208,8 +208,7 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable { } QueryResult queryResult = new QueryResult(); - QueryResultWithSchemaResp queryResp = semanticLayer.queryByStruct( - convertQueryStruct(), user); + QueryResultWithSchemaResp queryResp = semanticLayer.queryByStruct(convertQueryStruct(), user); if (queryResp != null) { queryResult.setQueryAuthorization(queryResp.getQueryAuthorization()); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/entity/EntityFilterQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/entity/EntityFilterQuery.java index 31e1039fe..bc88440c3 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/entity/EntityFilterQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/entity/EntityFilterQuery.java @@ -1,9 +1,9 @@ package com.tencent.supersonic.chat.query.rule.entity; -import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID; -import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.VALUE; +import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.*; import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.OPTIONAL; -import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST; +import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED; +import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.*; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Component; @@ -17,8 +17,7 @@ public class EntityFilterQuery extends EntityListQuery { public EntityFilterQuery() { super(); - queryMatcher.addOption(VALUE, OPTIONAL, AT_LEAST, 0); - queryMatcher.addOption(ID, OPTIONAL, AT_LEAST, 0); + queryMatcher.addOption(VALUE, REQUIRED, AT_LEAST, 1); } @Override diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/entity/EntityIdQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/entity/EntityIdQuery.java new file mode 100644 index 000000000..ad162effd --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/entity/EntityIdQuery.java @@ -0,0 +1,23 @@ +package com.tencent.supersonic.chat.query.rule.entity; + +import org.springframework.stereotype.Component; +import static com.tencent.supersonic.chat.api.pojo.SchemaElementType.ID; +import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.OptionType.REQUIRED; +import static com.tencent.supersonic.chat.query.rule.QueryMatchOption.RequireNumberType.AT_LEAST; + +@Component +public class EntityIdQuery extends EntityListQuery { + + public static final String QUERY_MODE = "ENTITY_ID"; + + public EntityIdQuery() { + super(); + queryMatcher.addOption(ID, REQUIRED, AT_LEAST, 1); + } + + @Override + public String getQueryMode() { + return QUERY_MODE; + } + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/rest/AgentController.java b/chat/core/src/main/java/com/tencent/supersonic/chat/rest/AgentController.java new file mode 100644 index 000000000..49c6c5474 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/rest/AgentController.java @@ -0,0 +1,51 @@ +package com.tencent.supersonic.chat.rest; + +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; +import com.tencent.supersonic.chat.agent.Agent; +import com.tencent.supersonic.chat.service.AgentService; +import org.springframework.web.bind.annotation.*; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.util.List; + +@RestController +@RequestMapping("/api/chat/agent") +public class AgentController { + + private AgentService agentService; + + public AgentController(AgentService agentService) { + this.agentService = agentService; + } + + @PostMapping + public boolean createAgent(@RequestBody Agent agent, + HttpServletRequest httpServletRequest, + HttpServletResponse httpServletResponse) { + User user = UserHolder.findUser(httpServletRequest, httpServletResponse); + agentService.createAgent(agent, user); + return true; + } + + @PutMapping + public boolean updateAgent(@RequestBody Agent agent, + HttpServletRequest httpServletRequest, + HttpServletResponse httpServletResponse) { + User user = UserHolder.findUser(httpServletRequest, httpServletResponse); + agentService.updateAgent(agent, user); + return true; + } + + @DeleteMapping("/{id}") + public boolean deleteAgent(@PathVariable("id") Integer id) { + agentService.deleteAgent(id); + return true; + } + + @RequestMapping("/getAgentList") + public List getAgentList() { + return agentService.getAgents(); + } + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatConfigController.java b/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatConfigController.java index b3da4bac9..6f34d8dda 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatConfigController.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatConfigController.java @@ -32,7 +32,7 @@ import org.springframework.web.bind.annotation.RestController; @RestController -@RequestMapping("/api/chat/conf") +@RequestMapping({"/api/chat/conf", "/openapi/chat/conf"}) public class ChatConfigController { @Autowired diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatController.java b/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatController.java index 44712a0c0..d836660eb 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatController.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatController.java @@ -18,7 +18,7 @@ import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; @RestController -@RequestMapping("/api/chat/manage") +@RequestMapping({"/api/chat/manage", "/openapi/chat/manage"}) public class ChatController { private final ChatService chatService; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatQueryController.java b/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatQueryController.java index 4aaadce89..43702f9c1 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatQueryController.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatQueryController.java @@ -20,7 +20,7 @@ import org.springframework.web.bind.annotation.RestController; * query controller */ @RestController -@RequestMapping("/api/chat/query") +@RequestMapping({"/api/chat/query", "/openapi/chat/query"}) public class ChatQueryController { @Autowired diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/rest/RecommendController.java b/chat/core/src/main/java/com/tencent/supersonic/chat/rest/RecommendController.java index 8a81b072e..dbeb5a449 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/rest/RecommendController.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/rest/RecommendController.java @@ -19,7 +19,7 @@ import org.springframework.web.bind.annotation.RestController; * recommend controller */ @RestController -@RequestMapping("/api/chat/") +@RequestMapping({"/api/chat/", "/openapi/chat/"}) public class RecommendController { @Autowired diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/AgentService.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/AgentService.java new file mode 100644 index 000000000..2d247f6ea --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/AgentService.java @@ -0,0 +1,19 @@ +package com.tencent.supersonic.chat.service; + +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.chat.agent.Agent; +import java.util.List; + +public interface AgentService { + + List getAgents(); + + void createAgent(Agent agent, User user); + + void updateAgent(Agent agent, User user); + + Agent getAgent(Integer id); + + void deleteAgent(Integer id); + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/AgentServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/AgentServiceImpl.java new file mode 100644 index 000000000..b4aa7dc86 --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/AgentServiceImpl.java @@ -0,0 +1,82 @@ +package com.tencent.supersonic.chat.service.impl; + +import com.alibaba.fastjson.JSONObject; +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.chat.agent.Agent; +import com.tencent.supersonic.chat.persistence.dataobject.AgentDO; +import com.tencent.supersonic.chat.persistence.repository.AgentRepository; +import com.tencent.supersonic.chat.service.AgentService; +import org.springframework.beans.BeanUtils; +import org.springframework.stereotype.Service; +import java.util.Date; +import java.util.List; +import java.util.stream.Collectors; + +@Service +public class AgentServiceImpl implements AgentService { + + private AgentRepository agentRepository; + + public AgentServiceImpl(AgentRepository agentRepository) { + this.agentRepository = agentRepository; + } + + @Override + public List getAgents() { + return getAgentDOList().stream() + .map(this::convert).collect(Collectors.toList()); + } + + @Override + public void createAgent(Agent agent, User user) { + agentRepository.createAgent(convert(agent, user)); + } + + @Override + public void updateAgent(Agent agent, User user) { + agentRepository.updateAgent(convert(agent, user)); + } + + @Override + public Agent getAgent(Integer id) { + if (id == null) { + return null; + } + return convert(agentRepository.getAgent(id)); + } + + @Override + public void deleteAgent(Integer id) { + agentRepository.deleteAgent(id); + } + + private List getAgentDOList() { + return agentRepository.getAgents(); + } + + private Agent convert(AgentDO agentDO){ + if (agentDO == null ) { + return null; + } + Agent agent = new Agent(); + BeanUtils.copyProperties(agentDO,agent); + agent.setAgentConfig(agentDO.getConfig()); + agent.setExamples(JSONObject.parseArray(agentDO.getExamples(), String.class)); + return agent; + } + + private AgentDO convert(Agent agent, User user){ + AgentDO agentDO = new AgentDO(); + BeanUtils.copyProperties(agent, agentDO); + agentDO.setConfig(agent.getAgentConfig()); + agentDO.setExamples(JSONObject.toJSONString(agent.getExamples())); + agentDO.setCreatedAt(new Date()); + agentDO.setCreatedBy(user.getName()); + agentDO.setUpdatedAt(new Date()); + agentDO.setUpdatedBy(user.getName()); + if (agentDO.getStatus() == null) { + agentDO.setStatus(1); + } + return agentDO; + } +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index dc6c08d2a..cedb2689c 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -2,26 +2,25 @@ package com.tencent.supersonic.chat.service.impl; import com.tencent.supersonic.auth.api.authentication.pojo.User; -import com.tencent.supersonic.chat.api.component.SchemaMapper; -import com.tencent.supersonic.chat.api.component.SemanticParser; -import com.tencent.supersonic.chat.api.component.SemanticQuery; +import com.tencent.supersonic.chat.api.component.*; import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq; -import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryState; -import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.query.QuerySelector; +import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq; +import com.tencent.supersonic.chat.query.QueryManager; import com.tencent.supersonic.chat.service.ChatService; import com.tencent.supersonic.chat.service.QueryService; import com.tencent.supersonic.chat.utils.ComponentFactory; -import com.tencent.supersonic.common.util.JsonUtil; +import java.util.Comparator; import java.util.List; import java.util.stream.Collectors; +import com.tencent.supersonic.common.util.JsonUtil; import lombok.extern.slf4j.Slf4j; import org.apache.calcite.sql.parser.SqlParseException; import org.springframework.beans.BeanUtils; @@ -63,12 +62,14 @@ public class QueryServiceImpl implements QueryService { if (queryCtx.getCandidateQueries().size() > 0) { log.debug("pick before [{}]", queryCtx.getCandidateQueries().stream().collect( Collectors.toList())); - List selectedQueries = querySelector.select(queryCtx.getCandidateQueries()); + List selectedQueries = querySelector.select(queryCtx.getCandidateQueries(), queryReq); log.debug("pick after [{}]", selectedQueries.stream().collect( Collectors.toList())); List selectedParses = selectedQueries.stream() - .map(SemanticQuery::getParseInfo).collect(Collectors.toList()); + .map(SemanticQuery::getParseInfo) + .sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed()) + .collect(Collectors.toList()); List candidateParses = queryCtx.getCandidateQueries().stream() .map(SemanticQuery::getParseInfo).collect(Collectors.toList()); @@ -138,7 +139,7 @@ public class QueryServiceImpl implements QueryService { if (queryCtx.getCandidateQueries().size() > 0) { log.info("pick before [{}]", queryCtx.getCandidateQueries().stream().collect( Collectors.toList())); - List selectedQueries = querySelector.select(queryCtx.getCandidateQueries()); + List selectedQueries = querySelector.select(queryCtx.getCandidateQueries(), queryReq); log.info("pick after [{}]", selectedQueries.stream().collect( Collectors.toList())); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/RecommendServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/RecommendServiceImpl.java index 07eb880fd..d5e8380d9 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/RecommendServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/RecommendServiceImpl.java @@ -57,6 +57,7 @@ public class RecommendServiceImpl implements RecommendService { item.setName(dimSchemaDesc.getName()); item.setBizName(dimSchemaDesc.getBizName()); item.setId(dimSchemaDesc.getId()); + item.setAlias(dimSchemaDesc.getAlias()); return item; }).collect(Collectors.toList()); @@ -70,6 +71,7 @@ public class RecommendServiceImpl implements RecommendService { item.setName(metricSchemaDesc.getName()); item.setBizName(metricSchemaDesc.getBizName()); item.setId(metricSchemaDesc.getId()); + item.setAlias(metricSchemaDesc.getAlias()); return item; }).collect(Collectors.toList()); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/SearchServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/SearchServiceImpl.java index c21baba7a..599a14a5a 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/SearchServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/SearchServiceImpl.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.chat.service.impl; import com.google.common.collect.Lists; import com.hankcs.hanlp.seg.common.Term; +import com.tencent.supersonic.chat.agent.Agent; import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; @@ -13,6 +14,7 @@ import com.tencent.supersonic.chat.mapper.MatchText; import com.tencent.supersonic.chat.mapper.ModelInfoStat; import com.tencent.supersonic.chat.mapper.ModelWithSemanticType; import com.tencent.supersonic.chat.mapper.SearchMatchStrategy; +import com.tencent.supersonic.chat.service.AgentService; import com.tencent.supersonic.chat.service.ChatService; import com.tencent.supersonic.chat.service.SearchService; import com.tencent.supersonic.chat.utils.NatureHelper; @@ -53,22 +55,34 @@ public class SearchServiceImpl implements SearchService { private ChatService chatService; @Autowired private SearchMatchStrategy searchMatchStrategy; + @Autowired + private AgentService agentService; @Override public List search(QueryReq queryCtx) { + + // 1. check search enable + Integer agentId = queryCtx.getAgentId(); + if (agentId != null) { + Agent agent = agentService.getAgent(agentId); + if (!agent.enableSearch()) { + return Lists.newArrayList(); + } + } + String queryText = queryCtx.getQueryText(); - // 1.get meta info + // 2.get meta info SemanticSchema semanticSchemaDb = schemaService.getSemanticSchema(); List metricsDb = semanticSchemaDb.getMetrics(); final Map modelToName = semanticSchemaDb.getModelIdToName(); - // 2.detect by segment + // 3.detect by segment List originals = HanlpHelper.getTerms(queryText); Map> regTextMap = searchMatchStrategy.match(queryText, originals, queryCtx.getModelId()); regTextMap.entrySet().stream().forEach(m -> HanlpHelper.transLetterOriginal(m.getValue())); - // 3.get the most matching data + // 4.get the most matching data Optional>> mostSimilarSearchResult = regTextMap.entrySet() .stream() .filter(entry -> CollectionUtils.isNotEmpty(entry.getValue())) @@ -77,7 +91,7 @@ public class SearchServiceImpl implements SearchService { ? entry1 : entry2); log.debug("mostSimilarSearchResult:{}", mostSimilarSearchResult); - // 4.optimize the results after the query + // 5.optimize the results after the query if (!mostSimilarSearchResult.isPresent()) { return Lists.newArrayList(); } @@ -89,11 +103,11 @@ public class SearchServiceImpl implements SearchService { List possibleModels = getPossibleModels(queryCtx, originals, modelStat, queryCtx.getModelId()); - // 4.1 priority dimension metric + // 5.1 priority dimension metric boolean existMetricAndDimension = searchMetricAndDimension(new HashSet<>(possibleModels), modelToName, searchTextEntry, searchResults); - // 4.2 process based on dimension values + // 5.2 process based on dimension values MatchText matchText = searchTextEntry.getKey(); Map natureToNameMap = getNatureToNameMap(searchTextEntry, new HashSet<>(possibleModels)); log.debug("possibleModels:{},natureToNameMap:{}", possibleModels, natureToNameMap); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/ChatGptHelper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/ChatGptHelper.java deleted file mode 100644 index cc5d193dd..000000000 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/ChatGptHelper.java +++ /dev/null @@ -1,77 +0,0 @@ -package com.tencent.supersonic.chat.utils; - - -import com.plexpt.chatgpt.ChatGPT; -import com.plexpt.chatgpt.entity.chat.ChatCompletion; -import com.plexpt.chatgpt.entity.chat.ChatCompletionResponse; -import com.plexpt.chatgpt.entity.chat.Message; -import com.plexpt.chatgpt.util.Proxys; -import java.net.Proxy; -import java.text.SimpleDateFormat; -import java.util.Arrays; -import java.util.Date; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.stereotype.Component; - - -@Component -public class ChatGptHelper { - - @Value("${llm.chatgpt.apikey:xx-xxxx}") - private String apiKey; - - @Value("${llm.chatgpt.apiHost:https://api.openai.com/}") - private String apiHost; - - @Value("${llm.chatgpt.proxyIp:default}") - private String proxyIp; - - @Value("${llm.chatgpt.proxyPort:8080}") - private Integer proxyPort; - - - public ChatGPT getChatGPT() { - Proxy proxy = null; - if (!"default".equals(proxyIp)) { - proxy = Proxys.http(proxyIp, proxyPort); - } - return ChatGPT.builder() - .apiKey(apiKey) - .proxy(proxy) - .timeout(900) - .apiHost(apiHost) //反向代理地址 - .build() - .init(); - } - - public String inferredTime(String queryText) { - long nowTime = System.currentTimeMillis(); - Date date = new Date(nowTime); - SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd"); - String formattedDate = sdf.format(date); - Message system = Message.ofSystem("现在时间 " + formattedDate + ",你是一个专业的数据分析师,你的任务是基于数据,专业的解答用户的问题。" - + "你需要遵守以下规则:\n" - + "1.返回规范的数据格式,json,如: 输入:近 10 天的日活跃数,输出:{\"start\":\"2023-07-21\",\"end\":\"2023-07-31\"}" - + "2.你对时间数据要求规范,能从近 10 天,国庆节,端午节,获取到相应的时间,填写到 json 中。\n" - + "3.你的数据时间,只有当前及之前时间即可,超过则回复去年\n" - + "4.只需要解析出时间,时间可以是时间月和年或日、日历采用公历\n" - + "5.时间给出要是绝对正确,不能瞎编\n" - ); - Message message = Message.of("输入:" + queryText + ",输出:"); - ChatCompletion chatCompletion = ChatCompletion.builder() - .model(ChatCompletion.Model.GPT_3_5_TURBO_16K.getName()) - .messages(Arrays.asList(system, message)) - .maxTokens(10000) - .temperature(0.9) - .build(); - ChatCompletionResponse response = getChatGPT().chatCompletion(chatCompletion); - Message res = response.getChoices().get(0).getMessage(); - return res.getContent(); - } - - public static void main(String[] args) { - - } - - -} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/ComponentFactory.java b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/ComponentFactory.java index d47479f9e..e32ebba1b 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/ComponentFactory.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/ComponentFactory.java @@ -3,11 +3,14 @@ package com.tencent.supersonic.chat.utils; import com.tencent.supersonic.chat.api.component.SchemaMapper; import com.tencent.supersonic.chat.api.component.SemanticLayer; import com.tencent.supersonic.chat.api.component.SemanticParser; -import com.tencent.supersonic.chat.parser.function.ModelResolver; -import com.tencent.supersonic.chat.query.QuerySelector; + +import com.tencent.supersonic.chat.api.component.DSLOptimizer; import java.util.ArrayList; import java.util.List; import java.util.Objects; + +import com.tencent.supersonic.chat.parser.function.ModelResolver; +import com.tencent.supersonic.chat.query.QuerySelector; import org.apache.commons.collections.CollectionUtils; import org.springframework.core.io.support.SpringFactoriesLoader; @@ -15,10 +18,11 @@ public class ComponentFactory { private static List schemaMappers = new ArrayList<>(); private static List semanticParsers = new ArrayList<>(); + + private static List dslCorrections = new ArrayList<>(); private static SemanticLayer semanticLayer; private static QuerySelector querySelector; private static ModelResolver modelResolver; - public static List getSchemaMappers() { return CollectionUtils.isEmpty(schemaMappers) ? init(SchemaMapper.class, schemaMappers) : schemaMappers; } @@ -27,6 +31,11 @@ public class ComponentFactory { return CollectionUtils.isEmpty(semanticParsers) ? init(SemanticParser.class, semanticParsers) : semanticParsers; } + public static List getSqlCorrections() { + return CollectionUtils.isEmpty(dslCorrections) ? init(DSLOptimizer.class, dslCorrections) : dslCorrections; + } + + public static SemanticLayer getSemanticLayer() { if (Objects.isNull(semanticLayer)) { semanticLayer = init(SemanticLayer.class); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/NatureHelper.java b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/NatureHelper.java index 51e00ec1c..bf9bd411f 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/utils/NatureHelper.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/utils/NatureHelper.java @@ -74,7 +74,7 @@ public class NatureHelper { return null; } - public static boolean isDimensionValueClassId(String nature) { + public static boolean isDimensionValueModelId(String nature) { if (StringUtils.isEmpty(nature)) { return false; } @@ -104,7 +104,7 @@ public class NatureHelper { } private static long getDimensionValueCount(List terms) { - return terms.stream().filter(term -> isDimensionValueClassId(term.nature.toString())).count(); + return terms.stream().filter(term -> isDimensionValueModelId(term.nature.toString())).count(); } private static long getDimensionCount(List terms) { diff --git a/chat/core/src/main/python/bin/env.sh b/chat/core/src/main/python/bin/env.sh index 6c815b871..271d50b30 100644 --- a/chat/core/src/main/python/bin/env.sh +++ b/chat/core/src/main/python/bin/env.sh @@ -1,4 +1,3 @@ -#!/usr/bin/env bash # python path export python_path="/usr/local/bin/python3.9" # pip path diff --git a/chat/core/src/main/resources/mapper/AgentDOMapper.xml b/chat/core/src/main/resources/mapper/AgentDOMapper.xml new file mode 100644 index 000000000..9e1fcd353 --- /dev/null +++ b/chat/core/src/main/resources/mapper/AgentDOMapper.xml @@ -0,0 +1,303 @@ + + + + + + + + + + + + + + + + + + + + + + + + + and ${criterion.condition} + + + and ${criterion.condition} #{criterion.value} + + + and ${criterion.condition} #{criterion.value} and #{criterion.secondValue} + + + and ${criterion.condition} + + #{listItem} + + + + + + + + + + + + + + + + + + and ${criterion.condition} + + + and ${criterion.condition} #{criterion.value} + + + and ${criterion.condition} #{criterion.value} and #{criterion.secondValue} + + + and ${criterion.condition} + + #{listItem} + + + + + + + + + + + id, name, description, status, examples, config, created_by, created_at, updated_by, + updated_at, enable_search + + + + + delete from s2_agent + where id = #{id,jdbcType=INTEGER} + + + insert into s2_agent (id, name, description, + status, examples, config, + created_by, created_at, updated_by, + updated_at, enable_search) + values (#{id,jdbcType=INTEGER}, #{name,jdbcType=VARCHAR}, #{description,jdbcType=VARCHAR}, + #{status,jdbcType=INTEGER}, #{examples,jdbcType=VARCHAR}, #{config,jdbcType=VARCHAR}, + #{createdBy,jdbcType=VARCHAR}, #{createdAt,jdbcType=TIMESTAMP}, #{updatedBy,jdbcType=VARCHAR}, + #{updatedAt,jdbcType=TIMESTAMP}, #{enableSearch,jdbcType=INTEGER}) + + + insert into s2_agent + + + id, + + + name, + + + description, + + + status, + + + examples, + + + config, + + + created_by, + + + created_at, + + + updated_by, + + + updated_at, + + + enable_search, + + + + + #{id,jdbcType=INTEGER}, + + + #{name,jdbcType=VARCHAR}, + + + #{description,jdbcType=VARCHAR}, + + + #{status,jdbcType=INTEGER}, + + + #{examples,jdbcType=VARCHAR}, + + + #{config,jdbcType=VARCHAR}, + + + #{createdBy,jdbcType=VARCHAR}, + + + #{createdAt,jdbcType=TIMESTAMP}, + + + #{updatedBy,jdbcType=VARCHAR}, + + + #{updatedAt,jdbcType=TIMESTAMP}, + + + #{enableSearch,jdbcType=INTEGER}, + + + + + + update s2_agent + + + id = #{record.id,jdbcType=INTEGER}, + + + name = #{record.name,jdbcType=VARCHAR}, + + + description = #{record.description,jdbcType=VARCHAR}, + + + status = #{record.status,jdbcType=INTEGER}, + + + examples = #{record.examples,jdbcType=VARCHAR}, + + + config = #{record.config,jdbcType=VARCHAR}, + + + created_by = #{record.createdBy,jdbcType=VARCHAR}, + + + created_at = #{record.createdAt,jdbcType=TIMESTAMP}, + + + updated_by = #{record.updatedBy,jdbcType=VARCHAR}, + + + updated_at = #{record.updatedAt,jdbcType=TIMESTAMP}, + + + enable_search = #{record.enableSearch,jdbcType=INTEGER}, + + + + + + + + update s2_agent + set id = #{record.id,jdbcType=INTEGER}, + name = #{record.name,jdbcType=VARCHAR}, + description = #{record.description,jdbcType=VARCHAR}, + status = #{record.status,jdbcType=INTEGER}, + examples = #{record.examples,jdbcType=VARCHAR}, + config = #{record.config,jdbcType=VARCHAR}, + created_by = #{record.createdBy,jdbcType=VARCHAR}, + created_at = #{record.createdAt,jdbcType=TIMESTAMP}, + updated_by = #{record.updatedBy,jdbcType=VARCHAR}, + updated_at = #{record.updatedAt,jdbcType=TIMESTAMP}, + enable_search = #{record.enableSearch,jdbcType=INTEGER} + + + + + + update s2_agent + + + name = #{name,jdbcType=VARCHAR}, + + + description = #{description,jdbcType=VARCHAR}, + + + status = #{status,jdbcType=INTEGER}, + + + examples = #{examples,jdbcType=VARCHAR}, + + + config = #{config,jdbcType=VARCHAR}, + + + created_by = #{createdBy,jdbcType=VARCHAR}, + + + created_at = #{createdAt,jdbcType=TIMESTAMP}, + + + updated_by = #{updatedBy,jdbcType=VARCHAR}, + + + updated_at = #{updatedAt,jdbcType=TIMESTAMP}, + + + enable_search = #{enableSearch,jdbcType=INTEGER}, + + + where id = #{id,jdbcType=INTEGER} + + + update s2_agent + set name = #{name,jdbcType=VARCHAR}, + description = #{description,jdbcType=VARCHAR}, + status = #{status,jdbcType=INTEGER}, + examples = #{examples,jdbcType=VARCHAR}, + config = #{config,jdbcType=VARCHAR}, + created_by = #{createdBy,jdbcType=VARCHAR}, + created_at = #{createdAt,jdbcType=TIMESTAMP}, + updated_by = #{updatedBy,jdbcType=VARCHAR}, + updated_at = #{updatedAt,jdbcType=TIMESTAMP}, + enable_search = #{enableSearch,jdbcType=INTEGER} + where id = #{id,jdbcType=INTEGER} + + \ No newline at end of file diff --git a/chat/core/src/main/resources/sql.ddl/chat.sql b/chat/core/src/main/resources/sql.ddl/chat.sql index c78549748..183e05673 100644 --- a/chat/core/src/main/resources/sql.ddl/chat.sql +++ b/chat/core/src/main/resources/sql.ddl/chat.sql @@ -59,4 +59,18 @@ CREATE TABLE `chat_query` KEY `common` (`question_id`), KEY `common1` (`user_name`), KEY `common2` (`chat_id`) -) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8; \ No newline at end of file +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8; + + +CREATE TABLE `chat` +( + `chat_id` bigint(8) NOT NULL AUTO_INCREMENT, + `chat_name` varchar(100) DEFAULT NULL, + `create_time` datetime DEFAULT NULL, + `last_time` datetime DEFAULT NULL, + `creator` varchar(30) DEFAULT NULL, + `last_question` varchar(200) DEFAULT NULL, + `is_delete` int(2) DEFAULT '0' COMMENT 'is deleted', + `is_top` int(2) DEFAULT '0' COMMENT 'is top', + PRIMARY KEY (`chat_id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8; diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/query/dsl/optimizer/DateFieldCorrectorTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/query/dsl/optimizer/DateFieldCorrectorTest.java new file mode 100644 index 000000000..51c08072b --- /dev/null +++ b/chat/core/src/test/java/com/tencent/supersonic/chat/query/dsl/optimizer/DateFieldCorrectorTest.java @@ -0,0 +1,37 @@ +package com.tencent.supersonic.chat.query.dsl.optimizer; + +import com.tencent.supersonic.chat.api.pojo.CorrectionInfo; +import com.tencent.supersonic.chat.api.pojo.SchemaElement; +import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import org.junit.Assert; +import org.junit.jupiter.api.Test; + +class DateFieldCorrectorTest { + + @Test + void rewriter() { + DateFieldCorrector dateFieldCorrector = new DateFieldCorrector(); + SemanticParseInfo parseInfo = new SemanticParseInfo(); + SchemaElement model = new SchemaElement(); + model.setId(2L); + parseInfo.setModel(model); + CorrectionInfo correctionInfo = CorrectionInfo.builder() + .sql("select count(歌曲名) from 歌曲库 ") + .parseInfo(parseInfo) + .build(); + + CorrectionInfo rewriter = dateFieldCorrector.rewriter(correctionInfo); + + Assert.assertEquals("SELECT count(歌曲名) FROM 歌曲库 WHERE 数据日期 = '2023-08-14'", rewriter.getSql()); + + correctionInfo = CorrectionInfo.builder() + .sql("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'") + .parseInfo(parseInfo) + .build(); + + rewriter = dateFieldCorrector.rewriter(correctionInfo); + + Assert.assertEquals("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'", rewriter.getSql()); + + } +} \ No newline at end of file diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/query/dsl/optimizer/SelectFieldAppendCorrectorTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/query/dsl/optimizer/SelectFieldAppendCorrectorTest.java new file mode 100644 index 000000000..d796dd12d --- /dev/null +++ b/chat/core/src/test/java/com/tencent/supersonic/chat/query/dsl/optimizer/SelectFieldAppendCorrectorTest.java @@ -0,0 +1,25 @@ +package com.tencent.supersonic.chat.query.dsl.optimizer; + +import static org.junit.jupiter.api.Assertions.*; + +import com.tencent.supersonic.chat.api.pojo.CorrectionInfo; +import org.junit.Assert; +import org.junit.jupiter.api.Test; + +class SelectFieldAppendCorrectorTest { + + @Test + void rewriter() { + SelectFieldAppendCorrector corrector = new SelectFieldAppendCorrector(); + CorrectionInfo correctionInfo = CorrectionInfo.builder() + .sql("select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 and 歌手名 = '邓紫棋' and sys_imp_date = '2023-08-09' and 歌曲发布时 = '2023-08-01' order by 播放量 desc limit 11") + .build(); + + CorrectionInfo rewriter = corrector.rewriter(correctionInfo); + + Assert.assertEquals( + "SELECT 歌曲名, 歌手名, 歌曲发布时, 发布日期 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '邓紫棋' AND sys_imp_date = '2023-08-09' AND 歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11", + rewriter.getSql()); + + } +} \ No newline at end of file diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/builder/DimensionWordBuilder.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/builder/DimensionWordBuilder.java index f8d5186bd..c991d89e4 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/builder/DimensionWordBuilder.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/builder/DimensionWordBuilder.java @@ -1,13 +1,17 @@ package com.tencent.supersonic.knowledge.dictionary.builder; import com.google.common.collect.Lists; + +import java.util.ArrayList; +import java.util.List; + import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.knowledge.dictionary.DictWord; import com.tencent.supersonic.knowledge.dictionary.DictWordType; -import java.util.List; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; +import org.springframework.util.CollectionUtils; /** * dimension word nature @@ -23,6 +27,7 @@ public class DimensionWordBuilder extends BaseWordBuilder { public List doGet(String word, SchemaElement schemaElement) { List result = Lists.newArrayList(); result.add(getOnwWordNature(word, schemaElement, false)); + result.addAll(getOnwWordNatureAlias(schemaElement, false)); if (nlpDimensionUseSuffix) { String reverseWord = StringUtils.reverse(word); if (StringUtils.isNotEmpty(word) && !word.equalsIgnoreCase(reverseWord)) { @@ -46,4 +51,16 @@ public class DimensionWordBuilder extends BaseWordBuilder { return dictWord; } + private List getOnwWordNatureAlias(SchemaElement schemaElement, boolean isSuffix) { + List dictWords = new ArrayList<>(); + if (CollectionUtils.isEmpty(schemaElement.getAlias())) { + return dictWords; + } + + for (String alias : schemaElement.getAlias()) { + dictWords.add(getOnwWordNature(alias, schemaElement, false)); + } + return dictWords; + } + } diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/builder/MetricWordBuilder.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/builder/MetricWordBuilder.java index 01fc833e0..d51c3190a 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/builder/MetricWordBuilder.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/dictionary/builder/MetricWordBuilder.java @@ -1,13 +1,17 @@ package com.tencent.supersonic.knowledge.dictionary.builder; import com.google.common.collect.Lists; + +import java.util.ArrayList; +import java.util.List; + import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.knowledge.dictionary.DictWord; import com.tencent.supersonic.knowledge.dictionary.DictWordType; -import java.util.List; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; +import org.springframework.util.CollectionUtils; /** * Metric DictWord @@ -22,6 +26,7 @@ public class MetricWordBuilder extends BaseWordBuilder { public List doGet(String word, SchemaElement schemaElement) { List result = Lists.newArrayList(); result.add(getOnwWordNature(word, schemaElement, false)); + result.addAll(getOnwWordNatureAlias(schemaElement, false)); if (nlpMetricUseSuffix) { String reverseWord = StringUtils.reverse(word); if (!word.equalsIgnoreCase(reverseWord)) { @@ -45,4 +50,16 @@ public class MetricWordBuilder extends BaseWordBuilder { return dictWord; } + private List getOnwWordNatureAlias(SchemaElement schemaElement, boolean isSuffix) { + List dictWords = new ArrayList<>(); + if (CollectionUtils.isEmpty(schemaElement.getAlias())) { + return dictWords; + } + + for (String alias : schemaElement.getAlias()) { + dictWords.add(getOnwWordNature(alias, schemaElement, false)); + } + return dictWords; + } + } diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticLayer.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticLayer.java index 52efe536b..e42c21c0a 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticLayer.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticLayer.java @@ -2,51 +2,38 @@ package com.tencent.supersonic.knowledge.semantic; import com.github.pagehelper.PageInfo; import com.tencent.supersonic.auth.api.authentication.pojo.User; -import com.tencent.supersonic.common.pojo.enums.AuthType; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.common.util.S2ThreadContext; +import com.tencent.supersonic.common.pojo.enums.AuthType; import com.tencent.supersonic.semantic.api.model.request.ModelSchemaFilterReq; import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq; import com.tencent.supersonic.semantic.api.model.request.PageMetricReq; -import com.tencent.supersonic.semantic.api.model.response.DimensionResp; -import com.tencent.supersonic.semantic.api.model.response.DomainResp; -import com.tencent.supersonic.semantic.api.model.response.MetricResp; -import com.tencent.supersonic.semantic.api.model.response.ModelResp; -import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp; -import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; +import com.tencent.supersonic.semantic.api.model.response.*; import com.tencent.supersonic.semantic.api.query.request.QueryDslReq; import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import com.tencent.supersonic.semantic.model.domain.DimensionService; -import com.tencent.supersonic.semantic.model.domain.DomainService; import com.tencent.supersonic.semantic.model.domain.MetricService; import com.tencent.supersonic.semantic.model.domain.ModelService; import com.tencent.supersonic.semantic.query.service.QueryService; import com.tencent.supersonic.semantic.query.service.SchemaService; import java.util.List; +import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; @Slf4j public class LocalSemanticLayer extends BaseSemanticLayer { private SchemaService schemaService; - private S2ThreadContext s2ThreadContext; - private DomainService domainService; private ModelService modelService; private DimensionService dimensionService; private MetricService metricService; + @SneakyThrows @Override - public QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user) { - try { - QueryService queryService = ContextUtils.getBean(QueryService.class); - QueryResultWithSchemaResp queryResultWithSchemaResp = queryService.queryByStruct(queryStructReq, user); - return queryResultWithSchemaResp; - } catch (Exception e) { - log.info("queryByStruct has an exception:{}", e.toString()); - } - return null; + public QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user){ + QueryService queryService = ContextUtils.getBean(QueryService.class); + return queryService.queryByStructWithAuth(queryStructReq, user); } @Override diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/ModelSchemaBuilder.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/ModelSchemaBuilder.java index 81620eaf3..43a120325 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/ModelSchemaBuilder.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/ModelSchemaBuilder.java @@ -8,21 +8,17 @@ import com.tencent.supersonic.semantic.api.model.pojo.Entity; import com.tencent.supersonic.semantic.api.model.response.DimSchemaResp; import com.tencent.supersonic.semantic.api.model.response.MetricSchemaResp; import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.stream.Collectors; -import org.apache.commons.lang3.StringUtils; import org.apache.logging.log4j.util.Strings; import org.springframework.beans.BeanUtils; import org.springframework.util.CollectionUtils; +import java.util.*; +import java.util.stream.Collectors; + public class ModelSchemaBuilder { + private static String aliasSplit = ","; + public static ModelSchema build(ModelSchemaResp resp) { ModelSchema domainSchema = new ModelSchema(); @@ -37,6 +33,13 @@ public class ModelSchemaBuilder { Set metrics = new HashSet<>(); for (MetricSchemaResp metric : resp.getMetrics()) { + + List alias = new ArrayList<>(); + String aliasStr = metric.getAlias(); + if (Strings.isNotEmpty(aliasStr)) { + alias = Arrays.asList(aliasStr.split(aliasSplit)); + } + SchemaElement metricToAdd = SchemaElement.builder() .model(resp.getId()) .id(metric.getId()) @@ -44,16 +47,10 @@ public class ModelSchemaBuilder { .bizName(metric.getBizName()) .type(SchemaElementType.METRIC) .useCnt(metric.getUseCnt()) + .alias(alias) .build(); metrics.add(metricToAdd); - String alias = metric.getAlias(); - if (StringUtils.isNotEmpty(alias)) { - SchemaElement alisMetricToAdd = new SchemaElement(); - BeanUtils.copyProperties(metricToAdd, alisMetricToAdd); - alisMetricToAdd.setName(alias); - metrics.add(alisMetricToAdd); - } } domainSchema.getMetrics().addAll(metrics); @@ -74,6 +71,11 @@ public class ModelSchemaBuilder { } } + List alias = new ArrayList<>(); + String aliasStr = dim.getAlias(); + if (Strings.isNotEmpty(aliasStr)) { + alias = Arrays.asList(aliasStr.split(aliasSplit)); + } SchemaElement dimToAdd = SchemaElement.builder() .model(resp.getId()) .id(dim.getId()) @@ -81,17 +83,10 @@ public class ModelSchemaBuilder { .bizName(dim.getBizName()) .type(SchemaElementType.DIMENSION) .useCnt(dim.getUseCnt()) + .alias(alias) .build(); dimensions.add(dimToAdd); - String alias = dim.getAlias(); - if (StringUtils.isNotEmpty(alias)) { - SchemaElement alisDimToAdd = new SchemaElement(); - BeanUtils.copyProperties(dimToAdd, alisDimToAdd); - alisDimToAdd.setName(alias); - dimensions.add(alisDimToAdd); - } - SchemaElement dimValueToAdd = SchemaElement.builder() .model(resp.getId()) .id(dim.getId()) @@ -115,7 +110,7 @@ public class ModelSchemaBuilder { .collect( Collectors.toMap(SchemaElement::getId, schemaElement -> schemaElement, (k1, k2) -> k2)); if (idAndDimPair.containsKey(entity.getEntityId())) { - entityElement = idAndDimPair.get(entity.getEntityId()); + BeanUtils.copyProperties(idAndDimPair.get(entity.getEntityId()), entityElement); entityElement.setType(SchemaElementType.ENTITY); } entityElement.setAlias(entity.getNames()); diff --git a/common/pom.xml b/common/pom.xml index b71ecb23e..e515542ba 100644 --- a/common/pom.xml +++ b/common/pom.xml @@ -118,6 +118,12 @@ + + com.github.plexpt + chatgpt + 4.1.2 + + com.github.pagehelper pagehelper diff --git a/common/src/main/java/com/tencent/supersonic/common/util/ChatGptHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/ChatGptHelper.java new file mode 100644 index 000000000..56f207540 --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/ChatGptHelper.java @@ -0,0 +1,130 @@ +package com.tencent.supersonic.common.util; + + +import com.plexpt.chatgpt.ChatGPT; +import com.plexpt.chatgpt.entity.chat.ChatCompletion; +import com.plexpt.chatgpt.entity.chat.ChatCompletionResponse; +import com.plexpt.chatgpt.entity.chat.Message; +import com.plexpt.chatgpt.util.Proxys; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; + +import java.net.Proxy; +import java.text.SimpleDateFormat; +import java.util.Arrays; +import java.util.Date; + + +@Component +@Slf4j +public class ChatGptHelper { + + @Value("${llm.chatgpt.apikey:}") + private String apiKey; + + @Value("${llm.chatgpt.apiHost:}") + private String apiHost; + + @Value("${llm.chatgpt.proxyIp:}") + private String proxyIp; + + @Value("${llm.chatgpt.proxyPort:}") + private Integer proxyPort; + + + public ChatGPT getChatGPT(){ + Proxy proxy = null; + if (!"default".equals(proxyIp)){ + proxy = Proxys.http(proxyIp, proxyPort); + } + return ChatGPT.builder() + .apiKey(apiKey) + .proxy(proxy) + .timeout(900) + .apiHost(apiHost) //反向代理地址 + .build() + .init(); + } + + public Message getChatCompletion(Message system,Message message){ + ChatCompletion chatCompletion = ChatCompletion.builder() + .model(ChatCompletion.Model.GPT_3_5_TURBO_16K.getName()) + .messages(Arrays.asList(system, message)) + .maxTokens(10000) + .temperature(0.9) + .build(); + ChatCompletionResponse response = getChatGPT().chatCompletion(chatCompletion); + return response.getChoices().get(0).getMessage(); + } + + public String inferredTime(String queryText){ + long nowTime = System.currentTimeMillis(); + Date date = new Date(nowTime); + SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd"); + String formattedDate = sdf.format(date); + Message system = Message.ofSystem("现在时间 "+formattedDate+",你是一个专业的数据分析师,你的任务是基于数据,专业的解答用户的问题。" + + "你需要遵守以下规则:\n" + + "1.返回规范的数据格式,json,如: 输入:近 10 天的日活跃数,输出:{\"start\":\"2023-07-21\",\"end\":\"2023-07-31\"}" + + "2.你对时间数据要求规范,能从近 10 天,国庆节,端午节,获取到相应的时间,填写到 json 中。\n"+ + "3.你的数据时间,只有当前及之前时间即可,超过则回复去年\n" + + "4.只需要解析出时间,时间可以是时间月和年或日、日历采用公历\n"+ + "5.时间给出要是绝对正确,不能瞎编\n" + ); + Message message = Message.of("输入:"+queryText+",输出:"); + Message res = getChatCompletion(system, message); + return res.getContent(); + } + + + public String mockAlias(String mockType,String name,String bizName,String table,String desc,Boolean isPercentage){ + String msg = "Assuming you are a professional data analyst specializing in indicators, you have a vast amount of data analysis indicator content. You are familiar with the basic format of the content,Now, Construct your answer Based on the following json-schema.\n" + + "{\n" + + "\"$schema\": \"http://json-schema.org/draft-07/schema#\",\n" + + "\"type\": \"array\",\n" + + "\"minItems\": 2,\n" + + "\"maxItems\": 4,\n" + + "\"items\": {\n" + + "\"type\": \"string\",\n" + + "\"description\": \"Assuming you are a data analyst and give a defined "+mockType+" name: " +name+","+ + "this "+mockType+" is from database and table: "+table+ ",This "+mockType+" calculates the field source: "+bizName+", The description of this indicator is: "+desc+", provide some aliases for this,please take chinese or english,but more chinese and Not repeating,\"\n" + + "},\n" + + "\"additionalProperties\":false}\n" + + "Please double-check whether the answer conforms to the format described in the JSON-schema.\n" + + "ANSWER JSON:"; + log.info("msg:{}",msg); + Message system = Message.ofSystem(""); + Message message = Message.of(msg); + Message res = getChatCompletion(system, message); + return res.getContent(); + } + + + + public String mockDimensionValueAlias(String json){ + String msg = "Assuming you are a professional data analyst specializing in indicators,for you a json list," + + "the required content to follow is as follows: " + + "1. The format of JSON," + + "2. Only return in JSON format," + + "3. the array item > 1 and < 5,more alias," + + "for example:input:[\"qq_music\",\"kugou_music\"],out:{\"tran\":[\"qq音乐\",\"酷狗音乐\"],\"alias\":{\"qq_music\":[\"q音\",\"qq音乐\"],\"kugou_music\":[\"kugou\",\"酷狗\"]}}," + + "now input: " + json + ","+ + "answer json:"; + log.info("msg:{}",msg); + Message system = Message.ofSystem(""); + Message message = Message.of(msg); + Message res = getChatCompletion(system, message); + return res.getContent(); + } + + + + + + public static void main(String[] args) { + ChatGptHelper chatGptHelper = new ChatGptHelper(); + System.out.println(chatGptHelper.mockAlias("","","","","",false)); + } + + +} diff --git a/common/src/main/java/com/tencent/supersonic/common/util/DateUtils.java b/common/src/main/java/com/tencent/supersonic/common/util/DateUtils.java index 558c3760a..5715b1df9 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/DateUtils.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/DateUtils.java @@ -1,11 +1,13 @@ package com.tencent.supersonic.common.util; -import java.text.ParseException; import java.text.SimpleDateFormat; +import java.time.LocalDate; import java.time.LocalDateTime; import java.time.format.DateTimeFormatter; +import java.time.temporal.TemporalAdjusters; import java.util.Calendar; import java.util.Date; +import java.util.Objects; import lombok.extern.slf4j.Slf4j; @Slf4j @@ -56,30 +58,37 @@ public class DateUtils { return dateFormat.format(calendar.getTime()); } - public static String getBeforeDate(String date, int intervalDay, DatePeriodEnum datePeriodEnum) { - Calendar calendar = Calendar.getInstance(); - SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_FORMAT_DOT); - try { - calendar.setTime(dateFormat.parse(date)); - } catch (ParseException e) { - log.error("parse error"); - } + DateTimeFormatter dateTimeFormatter = DateTimeFormatter.ofPattern(DATE_FORMAT_DOT); + LocalDate currentDate = LocalDate.parse(date, dateTimeFormatter); + LocalDate result = null; switch (datePeriodEnum) { case DAY: - calendar.set(Calendar.DATE, calendar.get(Calendar.DATE) - intervalDay); + result = currentDate.minusDays(intervalDay); break; case WEEK: - calendar.set(Calendar.DATE, calendar.get(Calendar.DATE) - intervalDay * 7); + result = currentDate.minusWeeks(intervalDay); + if (intervalDay == 0) { + result = result.with(TemporalAdjusters.previousOrSame(java.time.DayOfWeek.MONDAY)); + } break; case MONTH: - calendar.set(Calendar.MONTH, calendar.get(Calendar.MONTH) - intervalDay); + result = currentDate.minusMonths(intervalDay); + if (intervalDay == 0) { + result = result.with(TemporalAdjusters.firstDayOfMonth()); + } break; case YEAR: - calendar.set(Calendar.YEAR, calendar.get(Calendar.YEAR) - intervalDay); + result = currentDate.minusYears(intervalDay); + if (intervalDay == 0) { + result = result.with(TemporalAdjusters.firstDayOfYear()); + } break; default: } - return dateFormat.format(calendar.getTime()); + if (Objects.nonNull(result)) { + return result.format(DateTimeFormatter.ofPattern(DATE_FORMAT_DOT)); + } + return null; } } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/StringUtil.java b/common/src/main/java/com/tencent/supersonic/common/util/StringUtil.java index 314ec25bd..7aa7c2938 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/StringUtil.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/StringUtil.java @@ -9,7 +9,6 @@ public class StringUtil { public static final String COMMA_WRAPPER = "'%s'"; public static final String SPACE_WRAPPER = " %s "; - public static String getCommaWrap(String value) { return String.format(COMMA_WRAPPER, value); } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/CCJSqlParserUtils.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/CCJSqlParserUtils.java index 24cd506bc..da0e9cb1f 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/CCJSqlParserUtils.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/CCJSqlParserUtils.java @@ -41,7 +41,9 @@ public class CCJSqlParserUtils { } Set result = new HashSet<>(); Expression where = plainSelect.getWhere(); - where.accept(new FieldAcquireVisitor(result)); + if (Objects.nonNull(where)) { + where.accept(new FieldAcquireVisitor(result)); + } return new ArrayList<>(result); } @@ -166,7 +168,24 @@ public class CCJSqlParserUtils { if (Objects.nonNull(groupByElement)) { groupByElement.accept(new GroupByReplaceVisitor(fieldToBizName)); } - //5. add Waiting Expression + return selectStatement.toString(); + } + + + public static String replaceFunction(String sql) { + Select selectStatement = getSelect(sql); + SelectBody selectBody = selectStatement.getSelectBody(); + if (!(selectBody instanceof PlainSelect)) { + return sql; + } + PlainSelect plainSelect = (PlainSelect) selectBody; + //1. replace where dataDiff function + Expression where = plainSelect.getWhere(); + FunctionReplaceVisitor visitor = new FunctionReplaceVisitor(); + if (Objects.nonNull(where)) { + where.accept(visitor); + } + //2. add Waiting Expression List waitingForAdds = visitor.getWaitingForAdds(); addWaitingExpression(plainSelect, where, waitingForAdds); return selectStatement.toString(); @@ -181,9 +200,10 @@ public class CCJSqlParserUtils { if (where == null) { plainSelect.setWhere(expression); } else { - plainSelect.setWhere(new AndExpression(where, expression)); + where = new AndExpression(where, expression); } } + plainSelect.setWhere(where); } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldReplaceVisitor.java index b41a1164d..2ee2bd317 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FieldReplaceVisitor.java @@ -1,25 +1,15 @@ package com.tencent.supersonic.common.util.jsqlparser; -import java.util.ArrayList; -import java.util.List; import java.util.Map; -import java.util.Objects; import lombok.extern.slf4j.Slf4j; -import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; -import net.sf.jsqlparser.expression.operators.relational.GreaterThan; -import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals; -import net.sf.jsqlparser.expression.operators.relational.MinorThan; -import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals; import net.sf.jsqlparser.schema.Column; @Slf4j public class FieldReplaceVisitor extends ExpressionVisitorAdapter { ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper(); - private Map fieldToBizName; - private List waitingForAdds = new ArrayList<>(); public FieldReplaceVisitor(Map fieldToBizName) { this.fieldToBizName = fieldToBizName; @@ -29,42 +19,4 @@ public class FieldReplaceVisitor extends ExpressionVisitorAdapter { public void visit(Column column) { parseVisitorHelper.replaceColumn(column, fieldToBizName); } - - @Override - public void visit(MinorThan expr) { - Expression expression = parseVisitorHelper.reparseDate(expr, fieldToBizName, ">"); - if (Objects.nonNull(expression)) { - waitingForAdds.add(expression); - } - } - - @Override - public void visit(MinorThanEquals expr) { - Expression expression = parseVisitorHelper.reparseDate(expr, fieldToBizName, ">="); - if (Objects.nonNull(expression)) { - waitingForAdds.add(expression); - } - } - - - @Override - public void visit(GreaterThan expr) { - Expression expression = parseVisitorHelper.reparseDate(expr, fieldToBizName, "<"); - if (Objects.nonNull(expression)) { - waitingForAdds.add(expression); - } - } - - @Override - public void visit(GreaterThanEquals expr) { - Expression expression = parseVisitorHelper.reparseDate(expr, fieldToBizName, "<="); - if (Objects.nonNull(expression)) { - waitingForAdds.add(expression); - } - } - - public List getWaitingForAdds() { - return waitingForAdds; - } - } \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FunctionReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FunctionReplaceVisitor.java new file mode 100644 index 000000000..34b8e608d --- /dev/null +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/FunctionReplaceVisitor.java @@ -0,0 +1,171 @@ +package com.tencent.supersonic.common.util.jsqlparser; + +import com.tencent.supersonic.common.util.DatePeriodEnum; +import com.tencent.supersonic.common.util.DateUtils; +import com.tencent.supersonic.common.util.StringUtil; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import lombok.extern.slf4j.Slf4j; +import net.sf.jsqlparser.JSQLParserException; +import net.sf.jsqlparser.expression.DoubleValue; +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.ExpressionVisitorAdapter; +import net.sf.jsqlparser.expression.Function; +import net.sf.jsqlparser.expression.LongValue; +import net.sf.jsqlparser.expression.StringValue; +import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; +import net.sf.jsqlparser.expression.operators.relational.EqualsTo; +import net.sf.jsqlparser.expression.operators.relational.GreaterThan; +import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals; +import net.sf.jsqlparser.expression.operators.relational.MinorThan; +import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import net.sf.jsqlparser.schema.Column; +import org.apache.commons.collections.CollectionUtils; + +@Slf4j +public class FunctionReplaceVisitor extends ExpressionVisitorAdapter { + + public static final String DATE_FUNCTION = "datediff"; + public static final double HALF_YEAR = 0.5d; + public static final int SIX_MONTH = 6; + public static final String EQUAL = "="; + private List waitingForAdds = new ArrayList<>(); + + @Override + public void visit(MinorThan expr) { + List expressions = reparseDate(expr, ">"); + if (Objects.nonNull(expressions)) { + waitingForAdds.addAll(expressions); + } + } + + @Override + public void visit(EqualsTo expr) { + List expressions = reparseDate(expr, ">="); + if (Objects.nonNull(expressions)) { + waitingForAdds.addAll(expressions); + } + } + + @Override + public void visit(MinorThanEquals expr) { + List expressions = reparseDate(expr, ">="); + if (Objects.nonNull(expressions)) { + waitingForAdds.addAll(expressions); + } + } + + + @Override + public void visit(GreaterThan expr) { + List expressions = reparseDate(expr, "<"); + if (Objects.nonNull(expressions)) { + waitingForAdds.addAll(expressions); + } + } + + @Override + public void visit(GreaterThanEquals expr) { + List expressions = reparseDate(expr, "<="); + if (Objects.nonNull(expressions)) { + waitingForAdds.addAll(expressions); + } + } + + public List getWaitingForAdds() { + return waitingForAdds; + } + + + public List reparseDate(ComparisonOperator comparisonOperator, String startDateOperator) { + List result = new ArrayList<>(); + Expression leftExpression = comparisonOperator.getLeftExpression(); + if (!(leftExpression instanceof Function)) { + return result; + } + Function leftExpressionFunction = (Function) leftExpression; + if (!leftExpressionFunction.toString().contains(DATE_FUNCTION)) { + return result; + } + List leftExpressions = leftExpressionFunction.getParameters().getExpressions(); + if (CollectionUtils.isEmpty(leftExpressions) || leftExpressions.size() < 3) { + return result; + } + Column field = (Column) leftExpressions.get(1); + String columnName = field.getColumnName(); + try { + String startDateValue = getStartDateStr(comparisonOperator, leftExpressions); + String endDateValue = getEndDateValue(leftExpressions); + String endDateOperator = comparisonOperator.getStringExpression(); + String condExpr = + columnName + StringUtil.getSpaceWrap(getEndDateOperator(comparisonOperator)) + + StringUtil.getCommaWrap(endDateValue); + ComparisonOperator expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(condExpr); + + String startDataCondExpr = + columnName + StringUtil.getSpaceWrap(startDateOperator) + StringUtil.getCommaWrap(startDateValue); + + if (EQUAL.equalsIgnoreCase(endDateOperator)) { + result.add(CCJSqlParserUtil.parseCondExpression(condExpr)); + expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(" 1 = 1 "); + } + comparisonOperator.setLeftExpression(null); + comparisonOperator.setRightExpression(null); + comparisonOperator.setASTNode(null); + + comparisonOperator.setLeftExpression(expression.getLeftExpression()); + comparisonOperator.setRightExpression(expression.getRightExpression()); + comparisonOperator.setASTNode(expression.getASTNode()); + + result.add(CCJSqlParserUtil.parseCondExpression(startDataCondExpr)); + return result; + } catch (JSQLParserException e) { + log.error("JSQLParserException", e); + } + + return null; + } + + private String getStartDateStr(ComparisonOperator minorThanEquals, List expressions) { + String unitValue = getUnit(expressions); + String dateValue = getEndDateValue(expressions); + String dateStr = ""; + Expression rightExpression = minorThanEquals.getRightExpression(); + DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(unitValue); + if (rightExpression instanceof DoubleValue) { + DoubleValue value = (DoubleValue) rightExpression; + double doubleValue = value.getValue(); + if (DatePeriodEnum.YEAR.equals(datePeriodEnum) && doubleValue == HALF_YEAR) { + datePeriodEnum = DatePeriodEnum.MONTH; + dateStr = DateUtils.getBeforeDate(dateValue, SIX_MONTH, datePeriodEnum); + } + } else if (rightExpression instanceof LongValue) { + LongValue value = (LongValue) rightExpression; + long doubleValue = value.getValue(); + dateStr = DateUtils.getBeforeDate(dateValue, (int) doubleValue, datePeriodEnum); + } + return dateStr; + } + + private String getEndDateOperator(ComparisonOperator comparisonOperator) { + String operator = comparisonOperator.getStringExpression(); + if (EQUAL.equalsIgnoreCase(operator)) { + operator = "<="; + } + return operator; + } + + private String getEndDateValue(List leftExpressions) { + StringValue date = (StringValue) leftExpressions.get(2); + return date.getValue(); + } + + private String getUnit(List expressions) { + StringValue unit = (StringValue) expressions.get(0); + return unit.getValue(); + } + + +} \ No newline at end of file diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/GroupByReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/GroupByReplaceVisitor.java index d7c545b27..2971a64a6 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/GroupByReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/GroupByReplaceVisitor.java @@ -2,13 +2,18 @@ package com.tencent.supersonic.common.util.jsqlparser; import java.util.List; import java.util.Map; +import lombok.extern.slf4j.Slf4j; +import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.expression.Function; import net.sf.jsqlparser.expression.operators.relational.ExpressionList; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.statement.select.GroupByElement; import net.sf.jsqlparser.statement.select.GroupByVisitor; import org.apache.commons.lang3.StringUtils; +@Slf4j public class GroupByReplaceVisitor implements GroupByVisitor { ParseVisitorHelper parseVisitorHelper = new ParseVisitorHelper(); @@ -29,7 +34,17 @@ public class GroupByReplaceVisitor implements GroupByVisitor { String replaceColumn = parseVisitorHelper.getReplaceColumn(expression.toString(), fieldToBizName); if (StringUtils.isNotEmpty(replaceColumn)) { - groupByExpressions.set(i, new Column(replaceColumn)); + if (expression instanceof Column) { + groupByExpressions.set(i, new Column(replaceColumn)); + } + if (expression instanceof Function) { + try { + Expression element = CCJSqlParserUtil.parseExpression(replaceColumn); + ((Function) expression).getParameters().getExpressions().set(0, element); + } catch (JSQLParserException e) { + log.error("e", e); + } + } } } } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/ParseVisitorHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/ParseVisitorHelper.java index cd943ea4d..fde04277d 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/ParseVisitorHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/ParseVisitorHelper.java @@ -1,33 +1,16 @@ package com.tencent.supersonic.common.util.jsqlparser; -import com.tencent.supersonic.common.util.DatePeriodEnum; -import com.tencent.supersonic.common.util.DateUtils; -import com.tencent.supersonic.common.util.StringUtil; -import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; -import net.sf.jsqlparser.JSQLParserException; -import net.sf.jsqlparser.expression.DoubleValue; -import net.sf.jsqlparser.expression.Expression; -import net.sf.jsqlparser.expression.Function; -import net.sf.jsqlparser.expression.LongValue; -import net.sf.jsqlparser.expression.StringValue; -import net.sf.jsqlparser.expression.operators.relational.ComparisonOperator; -import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.schema.Column; -import org.apache.commons.collections.CollectionUtils; import org.apache.commons.lang3.StringUtils; @Slf4j public class ParseVisitorHelper { - public static final double HALF_YEAR = 0.5d; - public static final int SIX_MONTH = 6; - public static final String DATE_FUNCTION = "datediff"; - public void replaceColumn(Column column, Map fieldToBizName) { String columnName = column.getColumnName(); column.setColumnName(getReplaceColumn(columnName, fieldToBizName)); @@ -53,88 +36,6 @@ public class ParseVisitorHelper { return columnName; } - public Expression reparseDate(ComparisonOperator comparisonOperator, Map fieldToBizName, - String startDateOperator) { - Expression leftExpression = comparisonOperator.getLeftExpression(); - if (leftExpression instanceof Column) { - Column leftExpressionColumn = (Column) leftExpression; - replaceColumn(leftExpressionColumn, fieldToBizName); - return null; - } - - if (!(leftExpression instanceof Function)) { - return null; - } - Function leftExpressionFunction = (Function) leftExpression; - if (!leftExpressionFunction.toString().contains(DATE_FUNCTION)) { - return null; - } - List leftExpressions = leftExpressionFunction.getParameters().getExpressions(); - if (CollectionUtils.isEmpty(leftExpressions) || leftExpressions.size() < 3) { - return null; - } - Column field = (Column) leftExpressions.get(1); - String columnName = field.getColumnName(); - String startDateValue = getStartDateStr(comparisonOperator, leftExpressions); - String fieldBizName = fieldToBizName.get(columnName); - try { - String endDateValue = getEndDateValue(leftExpressions); - String stringExpression = comparisonOperator.getStringExpression(); - - String condExpr = - fieldBizName + StringUtil.getSpaceWrap(stringExpression) + StringUtil.getCommaWrap(endDateValue); - ComparisonOperator expression = (ComparisonOperator) CCJSqlParserUtil.parseCondExpression(condExpr); - - comparisonOperator.setLeftExpression(null); - comparisonOperator.setRightExpression(null); - comparisonOperator.setASTNode(null); - - comparisonOperator.setLeftExpression(expression.getLeftExpression()); - comparisonOperator.setRightExpression(expression.getRightExpression()); - comparisonOperator.setASTNode(expression.getASTNode()); - - String startDataCondExpr = - fieldBizName + StringUtil.getSpaceWrap(startDateOperator) + StringUtil.getCommaWrap(startDateValue); - return CCJSqlParserUtil.parseCondExpression(startDataCondExpr); - } catch (JSQLParserException e) { - log.error("JSQLParserException", e); - } - - return null; - } - - private String getEndDateValue(List leftExpressions) { - StringValue date = (StringValue) leftExpressions.get(2); - return date.getValue(); - } - - private String getStartDateStr(ComparisonOperator minorThanEquals, List expressions) { - String unitValue = getUnit(expressions); - String dateValue = getEndDateValue(expressions); - String dateStr = ""; - Expression rightExpression = minorThanEquals.getRightExpression(); - DatePeriodEnum datePeriodEnum = DatePeriodEnum.get(unitValue); - if (rightExpression instanceof DoubleValue) { - DoubleValue value = (DoubleValue) rightExpression; - double doubleValue = value.getValue(); - if (DatePeriodEnum.YEAR.equals(datePeriodEnum) && doubleValue == HALF_YEAR) { - datePeriodEnum = DatePeriodEnum.MONTH; - dateStr = DateUtils.getBeforeDate(dateValue, SIX_MONTH, datePeriodEnum); - } - } else if (rightExpression instanceof LongValue) { - LongValue value = (LongValue) rightExpression; - long doubleValue = value.getValue(); - dateStr = DateUtils.getBeforeDate(dateValue, (int) doubleValue, datePeriodEnum); - } - return dateStr; - } - - private String getUnit(List expressions) { - StringValue unit = (StringValue) expressions.get(0); - return unit.getValue(); - } - - public static int editDistance(String word1, String word2) { final int m = word1.length(); final int n = word2.length(); diff --git a/common/src/test/java/com/tencent/supersonic/common/util/DateUtilsTest.java b/common/src/test/java/com/tencent/supersonic/common/util/DateUtilsTest.java index 54e2bd5b4..b4f9f3211 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/DateUtilsTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/DateUtilsTest.java @@ -14,13 +14,25 @@ class DateUtilsTest { dateStr = DateUtils.getBeforeDate("2023-08-10", 8, DatePeriodEnum.DAY); Assert.assertEquals(dateStr, "2023-08-02"); + dateStr = DateUtils.getBeforeDate("2023-08-10", 0, DatePeriodEnum.DAY); + Assert.assertEquals(dateStr, "2023-08-10"); + dateStr = DateUtils.getBeforeDate("2023-08-10", 1, DatePeriodEnum.WEEK); Assert.assertEquals(dateStr, "2023-08-03"); + dateStr = DateUtils.getBeforeDate("2023-08-10", 0, DatePeriodEnum.WEEK); + Assert.assertEquals(dateStr, "2023-08-07"); + dateStr = DateUtils.getBeforeDate("2023-08-01", 1, DatePeriodEnum.MONTH); Assert.assertEquals(dateStr, "2023-07-01"); + dateStr = DateUtils.getBeforeDate("2023-08-10", 0, DatePeriodEnum.MONTH); + Assert.assertEquals(dateStr, "2023-08-01"); + dateStr = DateUtils.getBeforeDate("2023-08-01", 1, DatePeriodEnum.YEAR); Assert.assertEquals(dateStr, "2022-08-01"); + + dateStr = DateUtils.getBeforeDate("2023-08-10", 0, DatePeriodEnum.YEAR); + Assert.assertEquals(dateStr, "2023-01-01"); } } \ No newline at end of file diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/CCJSqlParserUtilsTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/CCJSqlParserUtilsTest.java index d38bc3678..427c16471 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/CCJSqlParserUtilsTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/CCJSqlParserUtilsTest.java @@ -22,37 +22,104 @@ class CCJSqlParserUtilsTest { String replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01' order by 播放量 desc limit 11"; replaceSql = CCJSqlParserUtils.replaceFields(replaceSql, fieldToBizName); + replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-08-09' AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' AND song_publis_date = '2023-08-01' AND publish_date >= '2023-08-08' ORDER BY play_count DESC LIMIT 11" + , replaceSql); + + replaceSql = "select YEAR(发行日期), count(歌曲名) from 歌曲库 where YEAR(发行日期) in (2022, 2023) and 数据日期 = '2023-08-14' group by YEAR(发行日期)"; + + replaceSql = CCJSqlParserUtils.replaceFields(replaceSql, fieldToBizName); + replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT YEAR(publish_date), count(song_name) FROM 歌曲库 WHERE YEAR(publish_date) IN (2022, 2023) AND sys_imp_date = '2023-08-14' GROUP BY YEAR(publish_date)", + replaceSql); + + replaceSql = "select YEAR(发行日期), count(歌曲名) from 歌曲库 where YEAR(发行日期) in (2022, 2023) and 数据日期 = '2023-08-14' group by 发行日期"; + + replaceSql = CCJSqlParserUtils.replaceFields(replaceSql, fieldToBizName); + replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT YEAR(publish_date), count(song_name) FROM 歌曲库 WHERE YEAR(publish_date) IN (2022, 2023) AND sys_imp_date = '2023-08-14' GROUP BY publish_date", + replaceSql); + + replaceSql = CCJSqlParserUtils.replaceFields( + "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-11') <= 1 and 结算播放量 > 1000000 and datediff('day', 数据日期, '2023-08-11') <= 30", + fieldToBizName); + replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-08-11' AND play_count > 1000000 AND sys_imp_date <= '2023-08-11' AND publish_date >= '2022-08-11' AND sys_imp_date >= '2023-07-12'" + , replaceSql); replaceSql = CCJSqlParserUtils.replaceFields( "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", fieldToBizName); + replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-08-09' AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' AND publish_date >= '2023-08-08' ORDER BY play_count DESC LIMIT 11" + , replaceSql); + + replaceSql = CCJSqlParserUtils.replaceFields( + "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') = 0 and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", + fieldToBizName); + replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT song_name FROM 歌曲库 WHERE 1 = 1 AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' AND publish_date <= '2023-08-09' AND publish_date >= '2023-01-01' ORDER BY play_count DESC LIMIT 11" + , replaceSql); replaceSql = CCJSqlParserUtils.replaceFields( "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') <= 0.5 and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", fieldToBizName); + replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-08-09' AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' AND publish_date >= '2023-02-09' ORDER BY play_count DESC LIMIT 11" + , replaceSql); replaceSql = CCJSqlParserUtils.replaceFields( "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') >= 0.5 and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", fieldToBizName); + replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT song_name FROM 歌曲库 WHERE publish_date >= '2023-08-09' AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' AND publish_date <= '2023-02-09' ORDER BY play_count DESC LIMIT 11" + , replaceSql); replaceSql = CCJSqlParserUtils.replaceFields( - "select 部门,用户 from 超音数 where 数据日期 = '2023-08-08' and 用户 =alice and 发布日期 ='11' order by 访问次数 desc limit 1", + "select 部门,用户 from 超音数 where 数据日期 = '2023-08-08' and 用户 ='alice' and 发布日期 ='11' order by 访问次数 desc limit 1", fieldToBizName); + replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT department, user_id FROM 超音数 WHERE sys_imp_date = '2023-08-08' AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1" + , replaceSql); replaceSql = CCJSqlParserUtils.replaceTable(replaceSql, "s2"); replaceSql = CCJSqlParserUtils.addFieldsToSelect(replaceSql, Collections.singletonList("field_a")); replaceSql = CCJSqlParserUtils.replaceFields( - "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' and 用户 =alice and 发布日期 ='11' group by 部门 limit 1", + "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' and 用户 ='alice' and 发布日期 ='11' group by 部门 limit 1", fieldToBizName); - Assert.assertEquals(replaceSql, - "SELECT department, sum(pv) FROM 超音数 WHERE sys_imp_date = '2023-08-08' AND user_id = user_id AND publish_date = '11' GROUP BY department LIMIT 1"); + replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql); + + Assert.assertEquals( + "SELECT department, sum(pv) FROM 超音数 WHERE sys_imp_date = '2023-08-08' AND user_id = 'alice' AND publish_date = '11' GROUP BY department LIMIT 1", + replaceSql); replaceSql = "select sum(访问次数) from 超音数 where 数据日期 >= '2023-08-06' and 数据日期 <= '2023-08-06' and 部门 = 'hr'"; replaceSql = CCJSqlParserUtils.replaceFields(replaceSql, fieldToBizName); + replaceSql = CCJSqlParserUtils.replaceFunction(replaceSql); - System.out.println(replaceSql); + Assert.assertEquals( + "SELECT sum(pv) FROM 超音数 WHERE sys_imp_date >= '2023-08-06' AND sys_imp_date <= '2023-08-06' AND department = 'hr'", + replaceSql); } diff --git a/launchers/chat/src/main/resources/META-INF/spring.factories b/launchers/chat/src/main/resources/META-INF/spring.factories index 1c807eb8c..c1c44be25 100644 --- a/launchers/chat/src/main/resources/META-INF/spring.factories +++ b/launchers/chat/src/main/resources/META-INF/spring.factories @@ -6,9 +6,10 @@ com.tencent.supersonic.chat.api.component.SchemaMapper=\ com.tencent.supersonic.chat.api.component.SemanticParser=\ com.tencent.supersonic.chat.parser.rule.QueryModeParser, \ com.tencent.supersonic.chat.parser.rule.ContextInheritParser, \ + com.tencent.supersonic.chat.parser.rule.AgentCheckParser, \ com.tencent.supersonic.chat.parser.rule.TimeRangeParser, \ com.tencent.supersonic.chat.parser.rule.AggregateTypeParser, \ - com.tencent.supersonic.chat.parser.llm.LLMDSLParser, \ + com.tencent.supersonic.chat.parser.llm.dsl.LLMDSLParser, \ com.tencent.supersonic.chat.parser.function.FunctionBasedParser com.tencent.supersonic.chat.api.component.SemanticLayer=\ com.tencent.supersonic.knowledge.semantic.RemoteSemanticLayer @@ -19,4 +20,13 @@ com.tencent.supersonic.chat.parser.function.ModelResolver=\ com.tencent.supersonic.auth.authentication.interceptor.AuthenticationInterceptor=\ com.tencent.supersonic.auth.authentication.interceptor.DefaultAuthenticationInterceptor com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\ - com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor \ No newline at end of file + com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor + + +com.tencent.supersonic.chat.api.component.DSLOptimizer=\ + com.tencent.supersonic.chat.query.dsl.optimizer.DateFieldCorrector, \ + com.tencent.supersonic.chat.query.dsl.optimizer.FieldCorrector, \ + com.tencent.supersonic.chat.query.dsl.optimizer.FunctionCorrector, \ + com.tencent.supersonic.chat.query.dsl.optimizer.TableNameCorrector, \ + com.tencent.supersonic.chat.query.dsl.optimizer.QueryFilterAppend, \ + com.tencent.supersonic.chat.query.dsl.optimizer.SelectFieldAppendCorrector \ No newline at end of file diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/ConfigureDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/ConfigureDemo.java index 1fc145c90..00e662cf0 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/ConfigureDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/ConfigureDemo.java @@ -1,6 +1,12 @@ package com.tencent.supersonic; +import com.alibaba.fastjson.JSONObject; +import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.chat.agent.Agent; +import com.tencent.supersonic.chat.agent.AgentConfig; +import com.tencent.supersonic.chat.agent.tool.AgentToolType; +import com.tencent.supersonic.chat.agent.tool.RuleQueryTool; import com.tencent.supersonic.chat.api.pojo.request.ChatAggConfigReq; import com.tencent.supersonic.chat.api.pojo.request.ChatConfigBaseReq; import com.tencent.supersonic.chat.api.pojo.request.ChatDefaultConfigReq; @@ -14,16 +20,14 @@ import com.tencent.supersonic.chat.plugin.Plugin; import com.tencent.supersonic.chat.plugin.PluginParseConfig; import com.tencent.supersonic.chat.query.plugin.ParamOption; import com.tencent.supersonic.chat.query.plugin.WebBase; -import com.tencent.supersonic.chat.service.ChatService; -import com.tencent.supersonic.chat.service.ConfigService; -import com.tencent.supersonic.chat.service.PluginService; -import com.tencent.supersonic.chat.service.QueryService; +import com.tencent.supersonic.chat.service.*; import com.tencent.supersonic.common.util.JsonUtil; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.boot.context.event.ApplicationReadyEvent; import org.springframework.context.ApplicationListener; import org.springframework.stereotype.Component; @@ -32,6 +36,7 @@ import org.springframework.stereotype.Component; @Slf4j public class ConfigureDemo implements ApplicationListener { + @Qualifier("chatQueryService") @Autowired private QueryService queryService; @Autowired @@ -40,6 +45,8 @@ public class ConfigureDemo implements ApplicationListener protected ConfigService configService; @Autowired private PluginService pluginService; + @Autowired + private AgentService agentService; private User user = User.getFakeUser(); @@ -175,43 +182,25 @@ public class ConfigureDemo implements ApplicationListener pluginService.createPlugin(plugin_1, user); } - private void addPlugin_2() { - Plugin plugin_2 = new Plugin(); - plugin_2.setType("DSL"); - plugin_2.setModelList(Arrays.asList(1L, 2L)); - plugin_2.setPattern(""); - plugin_2.setParseModeConfig(null); - plugin_2.setName("大模型语义解析"); - List examples = new ArrayList<>(); - examples.add("超音数访问次数最高的部门是哪个"); - examples.add("超音数访问人数最高的部门是哪个"); - PluginParseConfig parseConfig = PluginParseConfig.builder() - .name("DSL") - .description("这个工具能够将用户的自然语言查询转化为SQL语句,从而从数据库中的查询具体的数据。用于处理数据查询的问题,提供基于事实的数据") - .examples(examples) - .build(); - plugin_2.setParseModeConfig(JsonUtil.toString(parseConfig)); - pluginService.createPlugin(plugin_2, user); - } - - private void addPlugin_3() { - Plugin plugin_2 = new Plugin(); - plugin_2.setType("CONTENT_INTERPRET"); - plugin_2.setModelList(Arrays.asList(1L)); - plugin_2.setPattern("超音数最近访问情况怎么样"); - plugin_2.setParseModeConfig(null); - plugin_2.setName("内容解读"); - List examples = new ArrayList<>(); - examples.add("超音数最近访问情况怎么样"); - examples.add("超音数最近访问情况如何"); - PluginParseConfig parseConfig = PluginParseConfig.builder() - .name("supersonic_content_interpret") - .description("这个工具能够先查询到相关的数据并交给大模型进行解读, 最后返回解读结果") - .examples(examples) - .build(); - plugin_2.setParseModeConfig(JsonUtil.toString(parseConfig)); - pluginService.createPlugin(plugin_2, user); + private void addAgent() { + Agent agent = new Agent(); + agent.setId(1); + agent.setName("查信息"); + agent.setDescription("查信息"); + agent.setStatus(1); + agent.setEnableSearch(1); + agent.setExamples(Lists.newArrayList("超音数访问次数", "超音数访问人数", "alice 停留时长")); + AgentConfig agentConfig = new AgentConfig(); + RuleQueryTool ruleQueryTool = new RuleQueryTool(); + ruleQueryTool.setType(AgentToolType.RULE); + ruleQueryTool.setQueryModes(Lists.newArrayList( + "ENTITY_DETAIL", "ENTITY_LIST_FILTER", "ENTITY_ID", "METRIC_ENTITY", + "METRIC_FILTER", "METRIC_GROUPBY", "METRIC_MODEL", "METRIC_ORDERBY" + )); + agentConfig.getTools().add(ruleQueryTool); + agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); + agentService.createAgent(agent, User.getFakeUser()); } @Override @@ -220,8 +209,7 @@ public class ConfigureDemo implements ApplicationListener addDemoChatConfig_1(); addDemoChatConfig_2(); addPlugin_1(); - addPlugin_2(); - addPlugin_3(); + addAgent(); addSampleChats(); addSampleChats2(); } catch (Exception e) { diff --git a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories index 95f2dbbb4..67aecd08c 100644 --- a/launchers/standalone/src/main/resources/META-INF/spring.factories +++ b/launchers/standalone/src/main/resources/META-INF/spring.factories @@ -6,9 +6,10 @@ com.tencent.supersonic.chat.api.component.SchemaMapper=\ com.tencent.supersonic.chat.api.component.SemanticParser=\ com.tencent.supersonic.chat.parser.rule.QueryModeParser, \ com.tencent.supersonic.chat.parser.rule.ContextInheritParser, \ + com.tencent.supersonic.chat.parser.rule.AgentCheckParser, \ com.tencent.supersonic.chat.parser.rule.TimeRangeParser, \ com.tencent.supersonic.chat.parser.rule.AggregateTypeParser, \ - com.tencent.supersonic.chat.parser.llm.LLMDSLParser, \ + com.tencent.supersonic.chat.parser.llm.dsl.LLMDSLParser, \ com.tencent.supersonic.chat.parser.embedding.EmbeddingBasedParser, \ com.tencent.supersonic.chat.parser.function.FunctionBasedParser com.tencent.supersonic.chat.api.component.SemanticLayer=\ @@ -21,3 +22,11 @@ com.tencent.supersonic.auth.authentication.interceptor.AuthenticationInterceptor com.tencent.supersonic.auth.authentication.interceptor.DefaultAuthenticationInterceptor com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\ com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor + +com.tencent.supersonic.chat.api.component.DSLOptimizer=\ + com.tencent.supersonic.chat.query.dsl.optimizer.DateFieldCorrector, \ + com.tencent.supersonic.chat.query.dsl.optimizer.FieldCorrector, \ + com.tencent.supersonic.chat.query.dsl.optimizer.FunctionCorrector, \ + com.tencent.supersonic.chat.query.dsl.optimizer.TableNameCorrector, \ + com.tencent.supersonic.chat.query.dsl.optimizer.QueryFilterAppend, \ + com.tencent.supersonic.chat.query.dsl.optimizer.SelectFieldAppendCorrector \ No newline at end of file diff --git a/launchers/standalone/src/main/resources/db/schema-h2.sql b/launchers/standalone/src/main/resources/db/schema-h2.sql index 03f9e50f2..9fa1ec519 100644 --- a/launchers/standalone/src/main/resources/db/schema-h2.sql +++ b/launchers/standalone/src/main/resources/db/schema-h2.sql @@ -648,6 +648,22 @@ CREATE TABLE IF NOT EXISTS `s2_plugin` COMMENT ON TABLE s2_plugin IS 'plugin information table'; +CREATE TABLE IF NOT EXISTS s2_agent +( + id int AUTO_INCREMENT, + name varchar(100) null, + description varchar(500) null, + status int null, + examples varchar(500) null, + config varchar(2000) null, + created_by varchar(100) null, + created_at TIMESTAMP null, + updated_by varchar(100) null, + updated_at TIMESTAMP null, + enable_search int null, + PRIMARY KEY (`id`) +); COMMENT ON TABLE s2_agent IS 'assistant information table'; + -------demo for semantic and chat CREATE TABLE IF NOT EXISTS `s2_user_department` diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/BaseQueryTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/BaseQueryTest.java index bc6ebe11d..7b8d5f49b 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/BaseQueryTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/BaseQueryTest.java @@ -11,6 +11,7 @@ import com.tencent.supersonic.chat.api.pojo.request.QueryReq; import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryState; +import com.tencent.supersonic.chat.service.AgentService; import com.tencent.supersonic.chat.service.ChatService; import com.tencent.supersonic.chat.service.ConfigService; import com.tencent.supersonic.chat.service.QueryService; @@ -22,6 +23,7 @@ import org.junit.runner.RunWith; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.mock.mockito.MockBean; import org.springframework.test.context.ActiveProfiles; import org.springframework.test.context.junit4.SpringRunner; @@ -42,6 +44,8 @@ public class BaseQueryTest { protected ChatService chatService; @Autowired protected ConfigService configService; + @MockBean + protected AgentService agentService; protected QueryResult submitMultiTurnChat(String queryText) throws Exception { ParseResp parseResp = submitParse(queryText); @@ -78,6 +82,11 @@ public class BaseQueryTest { return queryService.performParsing(queryContextReq); } + protected ParseResp submitParseWithAgent(String queryText, Integer agentId) { + QueryReq queryContextReq = DataUtils.getQueryReqWithAgent(10, queryText, agentId); + return queryService.performParsing(queryContextReq); + } + protected void assertSchemaElements(Set expected, Set actual) { Set expectedNames = expected.stream().map(s -> s.getName()) .filter(s -> s != null).collect(Collectors.toSet()); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java new file mode 100644 index 000000000..a1da76abd --- /dev/null +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricInterpretTest.java @@ -0,0 +1,56 @@ +package com.tencent.supersonic.integration; + +import com.alibaba.fastjson.JSONObject; +import com.tencent.supersonic.StandaloneLauncher; +import com.tencent.supersonic.chat.api.pojo.request.QueryReq; +import com.tencent.supersonic.chat.api.pojo.response.QueryResult; +import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig; +import com.tencent.supersonic.chat.plugin.PluginManager; +import com.tencent.supersonic.chat.query.metricInterpret.LLmAnswerResp; +import com.tencent.supersonic.chat.service.AgentService; +import com.tencent.supersonic.chat.service.QueryService; +import com.tencent.supersonic.util.DataUtils; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.mock.mockito.MockBean; +import org.springframework.http.ResponseEntity; +import org.springframework.test.context.ActiveProfiles; +import org.springframework.test.context.junit4.SpringRunner; + +@RunWith(SpringRunner.class) +@SpringBootTest(classes = StandaloneLauncher.class) +@ActiveProfiles("local") +public class MetricInterpretTest { + + @MockBean + private AgentService agentService; + + @MockBean + private PluginManager pluginManager; + + @MockBean + private EmbeddingConfig embeddingConfig; + + @Autowired + @Qualifier("chatQueryService") + private QueryService queryService; + + @Test + public void testMetricInterpret() throws Exception { + MockConfiguration.mockAgent(agentService); + MockConfiguration.mockEmbeddingUrl(embeddingConfig); + LLmAnswerResp lLmAnswerResp = new LLmAnswerResp(); + lLmAnswerResp.setAssistant_message("alice最近在超音数的访问情况有增多"); + MockConfiguration.mockPluginManagerDoRequest(pluginManager, "answer_with_plugin_call", + ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp))); + QueryReq queryReq = DataUtils.getQueryReqWithAgent(1000, "能不能帮我解读分析下最近alice在超音数的访问情况", + DataUtils.getAgent().getId()); + QueryResult queryResult = queryService.executeQuery(queryReq); + Assert.assertEquals(queryResult.getQueryResults().get(0).get("answer"), lLmAnswerResp.getAssistant_message()); + } + +} diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricQueryTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricQueryTest.java index 941360dc6..f34309d93 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricQueryTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricQueryTest.java @@ -1,30 +1,31 @@ package com.tencent.supersonic.integration; -import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.NONE; -import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM; - import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq; -import com.tencent.supersonic.chat.api.pojo.request.ItemVisibility; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; -import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp; +import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; +import com.tencent.supersonic.chat.api.pojo.request.ChatConfigEditReqReq; +import com.tencent.supersonic.chat.api.pojo.response.ChatConfigResp; +import com.tencent.supersonic.chat.api.pojo.request.ItemVisibility; +import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery; import com.tencent.supersonic.chat.query.rule.metric.MetricFilterQuery; import com.tencent.supersonic.chat.query.rule.metric.MetricGroupByQuery; -import com.tencent.supersonic.chat.query.rule.metric.MetricModelQuery; import com.tencent.supersonic.chat.query.rule.metric.MetricTopNQuery; import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum; import com.tencent.supersonic.util.DataUtils; -import java.text.DateFormat; -import java.text.SimpleDateFormat; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; +import org.junit.Assert; import org.junit.Test; import org.springframework.beans.BeanUtils; +import java.text.DateFormat; +import java.text.SimpleDateFormat; +import java.util.*; +import java.util.stream.Collectors; + +import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.*; + public class MetricQueryTest extends BaseQueryTest { @@ -50,6 +51,17 @@ public class MetricQueryTest extends BaseQueryTest { assertQueryResult(expectedResult, actualResult); } + @Test + public void queryTest_METRIC_FILTER_with_agent() { + //agent only support METRIC_ENTITY, METRIC_FILTER + MockConfiguration.mockAgent(agentService); + ParseResp parseResp = submitParseWithAgent("alice的访问次数", DataUtils.getAgent().getId()); + Assert.assertNotNull(parseResp.getSelectedParses()); + List queryModes = parseResp.getSelectedParses().stream() + .map(SemanticParseInfo::getQueryMode).collect(Collectors.toList()); + Assert.assertTrue(queryModes.contains("METRIC_FILTER")); + } + @Test public void queryTest_METRIC_DOMAIN() throws Exception { QueryResult actualResult = submitNewChat("超音数的访问次数"); @@ -69,6 +81,16 @@ public class MetricQueryTest extends BaseQueryTest { assertQueryResult(expectedResult, actualResult); } + @Test + public void queryTest_METRIC_MODEL_with_agent() { + //agent only support METRIC_ENTITY, METRIC_FILTER + MockConfiguration.mockAgent(agentService); + ParseResp parseResp = submitParseWithAgent("超音数的访问次数", DataUtils.getAgent().getId()); + List queryModes = parseResp.getSelectedParses().stream() + .map(SemanticParseInfo::getQueryMode).collect(Collectors.toList()); + Assert.assertTrue(queryModes.contains("METRIC_MODEL")); + } + @Test public void queryTest_METRIC_GROUPBY() throws Exception { QueryResult actualResult = submitNewChat("超音数各部门的访问次数"); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/plugin/PluginMockConfiguration.java b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MockConfiguration.java similarity index 79% rename from launchers/standalone/src/test/java/com/tencent/supersonic/integration/plugin/PluginMockConfiguration.java rename to launchers/standalone/src/test/java/com/tencent/supersonic/integration/MockConfiguration.java index 076861c86..fb95fe332 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/plugin/PluginMockConfiguration.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MockConfiguration.java @@ -1,22 +1,23 @@ -package com.tencent.supersonic.integration.plugin; +package com.tencent.supersonic.integration; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.notNull; -import static org.mockito.Mockito.when; - import com.google.common.collect.Lists; import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig; import com.tencent.supersonic.chat.parser.embedding.EmbeddingResp; import com.tencent.supersonic.chat.parser.embedding.RecallRetrieval; import com.tencent.supersonic.chat.plugin.PluginManager; +import com.tencent.supersonic.chat.service.AgentService; +import com.tencent.supersonic.util.DataUtils; import lombok.extern.slf4j.Slf4j; import org.springframework.context.annotation.Configuration; import org.springframework.http.ResponseEntity; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.notNull; +import static org.mockito.Mockito.when; @Configuration @Slf4j -public class PluginMockConfiguration { +public class MockConfiguration { public static void mockEmbeddingRecognize(PluginManager pluginManager, String text, String id) { EmbeddingResp embeddingResp = new EmbeddingResp(); @@ -33,9 +34,12 @@ public class PluginMockConfiguration { when(embeddingConfig.getUrl()).thenReturn("test"); } - public static void mockPluginManagerDoRequest(PluginManager pluginManager, String path, - ResponseEntity responseEntity) { + public static void mockPluginManagerDoRequest(PluginManager pluginManager, String path, ResponseEntity responseEntity) { when(pluginManager.doRequest(eq(path), notNull(String.class))).thenReturn(responseEntity); } + public static void mockAgent(AgentService agentService) { + when(agentService.getAgent(1)).thenReturn(DataUtils.getAgent()); + } + } diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/plugin/PluginRecognizeTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/plugin/PluginRecognizeTest.java index be8a03bf3..56836c7a1 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/plugin/PluginRecognizeTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/plugin/PluginRecognizeTest.java @@ -1,23 +1,23 @@ package com.tencent.supersonic.integration.plugin; -import com.alibaba.fastjson.JSONObject; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; +import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.parser.embedding.EmbeddingConfig; import com.tencent.supersonic.chat.plugin.PluginManager; -import com.tencent.supersonic.chat.query.ContentInterpret.LLmAnswerResp; +import com.tencent.supersonic.chat.service.AgentService; import com.tencent.supersonic.chat.service.QueryService; +import com.tencent.supersonic.integration.MockConfiguration; import com.tencent.supersonic.util.DataUtils; import org.junit.Assert; import org.junit.Test; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.boot.test.mock.mockito.MockBean; -import org.springframework.http.ResponseEntity; -public class PluginRecognizeTest extends BasePluginTest { +public class PluginRecognizeTest extends BasePluginTest{ @MockBean private EmbeddingConfig embeddingConfig; @@ -25,14 +25,17 @@ public class PluginRecognizeTest extends BasePluginTest { @MockBean protected PluginManager pluginManager; + @MockBean + protected AgentService agentService; + @Autowired @Qualifier("chatQueryService") private QueryService queryService; @Test public void webPageRecognize() throws Exception { - PluginMockConfiguration.mockEmbeddingRecognize(pluginManager, "alice最近的访问情况怎么样", "1"); - PluginMockConfiguration.mockEmbeddingUrl(embeddingConfig); + MockConfiguration.mockEmbeddingRecognize(pluginManager, "alice最近的访问情况怎么样","1"); + MockConfiguration.mockEmbeddingUrl(embeddingConfig); QueryReq queryContextReq = DataUtils.getQueryContextReq(1000, "alice最近的访问情况怎么样"); QueryResult queryResult = queryService.executeQuery(queryContextReq); assertPluginRecognizeResult(queryResult); @@ -40,8 +43,8 @@ public class PluginRecognizeTest extends BasePluginTest { @Test public void webPageRecognizeWithQueryFilter() throws Exception { - PluginMockConfiguration.mockEmbeddingRecognize(pluginManager, "在超音数最近的情况怎么样", "1"); - PluginMockConfiguration.mockEmbeddingUrl(embeddingConfig); + MockConfiguration.mockEmbeddingRecognize(pluginManager, "在超音数最近的情况怎么样","1"); + MockConfiguration.mockEmbeddingUrl(embeddingConfig); QueryReq queryRequest = DataUtils.getQueryContextReq(1000, "在超音数最近的情况怎么样"); QueryFilters queryFilters = new QueryFilters(); QueryFilter queryFilter = new QueryFilter(); @@ -55,17 +58,15 @@ public class PluginRecognizeTest extends BasePluginTest { } @Test - public void contentInterpretRecognize() throws Exception { - PluginMockConfiguration.mockEmbeddingRecognize(pluginManager, "超音数最近访问情况怎么样", "3"); - PluginMockConfiguration.mockEmbeddingUrl(embeddingConfig); - LLmAnswerResp lLmAnswerResp = new LLmAnswerResp(); - lLmAnswerResp.setAssistant_message("超音数最近访问情况不错"); - PluginMockConfiguration.mockPluginManagerDoRequest(pluginManager, "answer_with_plugin_call", - ResponseEntity.ok(JSONObject.toJSONString(lLmAnswerResp))); - QueryReq queryRequest = DataUtils.getQueryContextReq(1000, "超音数最近访问情况怎么样"); - QueryResult queryResult = queryService.executeQuery(queryRequest); - Assert.assertEquals(queryResult.getResponse(), lLmAnswerResp.getAssistant_message()); - System.out.println(); + public void pluginRecognizeWithAgent() { + MockConfiguration.mockEmbeddingRecognize(pluginManager, "alice最近的访问情况怎么样","1"); + MockConfiguration.mockEmbeddingUrl(embeddingConfig); + MockConfiguration.mockAgent(agentService); + QueryReq queryContextReq = DataUtils.getQueryReqWithAgent(1000, "alice最近的访问情况怎么样", + DataUtils.getAgent().getId()); + ParseResp parseResp = queryService.performParsing(queryContextReq); + Assert.assertTrue(parseResp.getSelectedParses() != null + && parseResp.getSelectedParses().size() > 0); } } diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java b/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java index dfa463e59..d478b531d 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java @@ -1,16 +1,26 @@ package com.tencent.supersonic.util; -import static java.time.LocalDate.now; - +import com.alibaba.fastjson.JSONObject; +import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.chat.agent.Agent; +import com.tencent.supersonic.chat.agent.AgentConfig; +import com.tencent.supersonic.chat.agent.tool.AgentToolType; +import com.tencent.supersonic.chat.agent.tool.MetricInterpretTool; +import com.tencent.supersonic.chat.agent.tool.PluginTool; +import com.tencent.supersonic.chat.agent.tool.RuleQueryTool; import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaElementType; import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; +import com.tencent.supersonic.chat.parser.llm.interpret.MetricOption; import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum; + import java.util.Set; +import static java.time.LocalDate.now; + public class DataUtils { private static final User user_test = new User(1L, "admin", "admin", "admin@email"); @@ -27,6 +37,15 @@ public class DataUtils { return queryContextReq; } + public static QueryReq getQueryReqWithAgent(Integer id, String query, Integer agentId) { + QueryReq queryReq = new QueryReq(); + queryReq.setQueryText(query);//"alice的访问次数" + queryReq.setChatId(id); + queryReq.setUser(user_test); + queryReq.setAgentId(agentId); + return queryReq; + } + public static SchemaElement getSchemaElement(String name) { return SchemaElement.builder() .name(name) @@ -55,9 +74,8 @@ public class DataUtils { .build(); } - public static QueryFilter getFilter(String bizName, FilterOperatorEnum filterOperatorEnum, Object value, - String name, - Long elementId) { + public static QueryFilter getFilter(String bizName, FilterOperatorEnum filterOperatorEnum, Object value, String name, + Long elementId) { QueryFilter filter = new QueryFilter(); filter.setBizName(bizName); filter.setOperator(filterOperatorEnum); @@ -77,8 +95,7 @@ public class DataUtils { return dateInfo; } - public static DateConf getDateConf(DateConf.DateMode dateMode, Integer unit, String period, String startDate, - String endDate) { + public static DateConf getDateConf(DateConf.DateMode dateMode, Integer unit, String period, String startDate, String endDate) { DateConf dateInfo = new DateConf(); dateInfo.setUnit(unit); dateInfo.setDateMode(dateMode); @@ -129,4 +146,43 @@ public class DataUtils { return dimensionFilterExist; } + + public static Agent getAgent() { + Agent agent = new Agent(); + agent.setId(1); + agent.setName("查信息"); + agent.setDescription("查信息"); + AgentConfig agentConfig = new AgentConfig(); + agentConfig.getTools().add(getRuleQueryTool()); + agentConfig.getTools().add(getPluginTool()); + agentConfig.getTools().add(getMetricInterpretTool()); + agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); + return agent; + } + + private static RuleQueryTool getRuleQueryTool() { + RuleQueryTool ruleQueryTool = new RuleQueryTool(); + ruleQueryTool.setType(AgentToolType.RULE); + ruleQueryTool.setQueryModes(Lists.newArrayList("METRIC_ENTITY", "METRIC_FILTER", "METRIC_MODEL")); + return ruleQueryTool; + } + + private static PluginTool getPluginTool() { + PluginTool pluginTool = new PluginTool(); + pluginTool.setType(AgentToolType.PLUGIN); + pluginTool.setPlugins(Lists.newArrayList(1L)); + return pluginTool; + } + + private static MetricInterpretTool getMetricInterpretTool() { + MetricInterpretTool metricInterpretTool = new MetricInterpretTool(); + metricInterpretTool.setModelId(1L); + metricInterpretTool.setType(AgentToolType.INTERPRET); + metricInterpretTool.setMetricOptions(Lists.newArrayList( + new MetricOption(1L), + new MetricOption(2L), + new MetricOption(3L))); + return metricInterpretTool; + } + } diff --git a/launchers/standalone/src/test/resources/META-INF/spring.factories b/launchers/standalone/src/test/resources/META-INF/spring.factories index 1f2e4f3d3..a7c81c6cc 100644 --- a/launchers/standalone/src/test/resources/META-INF/spring.factories +++ b/launchers/standalone/src/test/resources/META-INF/spring.factories @@ -3,8 +3,10 @@ com.tencent.supersonic.chat.api.component.SchemaMapper=\ com.tencent.supersonic.chat.api.component.SemanticParser=\ com.tencent.supersonic.chat.parser.rule.QueryModeParser, \ com.tencent.supersonic.chat.parser.rule.ContextInheritParser, \ + com.tencent.supersonic.chat.parser.rule.AgentCheckParser, \ com.tencent.supersonic.chat.parser.rule.TimeRangeParser, \ - com.tencent.supersonic.chat.parser.rule.AggregateTypeParser + com.tencent.supersonic.chat.parser.rule.AggregateTypeParser, \ + com.tencent.supersonic.chat.parser.llm.interpret.MetricInterpretParser # com.tencent.supersonic.chat.parser.llm.DSLQueryFunction com.tencent.supersonic.chat.api.component.QueryProcessor=\ com.tencent.supersonic.chat.application.processor.SemanticQueryProcessor diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/CatalogImpl.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/CatalogImpl.java index 0f2ef7470..1971a2c7d 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/CatalogImpl.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/CatalogImpl.java @@ -1,26 +1,13 @@ package com.tencent.supersonic.semantic.model.application; import com.tencent.supersonic.semantic.api.model.pojo.ItemDateFilter; -import com.tencent.supersonic.semantic.api.model.response.DatabaseResp; -import com.tencent.supersonic.semantic.api.model.response.DatasourceResp; -import com.tencent.supersonic.semantic.api.model.response.DimensionResp; -import com.tencent.supersonic.semantic.api.model.response.ItemDateResp; -import com.tencent.supersonic.semantic.api.model.response.MetricResp; -import com.tencent.supersonic.semantic.api.model.response.ModelResp; +import com.tencent.supersonic.semantic.api.model.response.*; import com.tencent.supersonic.semantic.api.model.yaml.DatasourceYamlTpl; import com.tencent.supersonic.semantic.api.model.yaml.DimensionYamlTpl; import com.tencent.supersonic.semantic.api.model.yaml.MetricYamlTpl; -import com.tencent.supersonic.semantic.model.domain.Catalog; -import com.tencent.supersonic.semantic.model.domain.DatasourceService; -import com.tencent.supersonic.semantic.model.domain.DimensionService; -import com.tencent.supersonic.semantic.model.domain.MetricService; -import com.tencent.supersonic.semantic.model.domain.ModelService; -import com.tencent.supersonic.semantic.model.domain.dataobject.DatabaseDO; -import com.tencent.supersonic.semantic.model.domain.repository.DatabaseRepository; -import com.tencent.supersonic.semantic.model.domain.utils.DatabaseConverter; +import com.tencent.supersonic.semantic.model.domain.*; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Component; @@ -29,17 +16,17 @@ import org.springframework.stereotype.Component; @Component public class CatalogImpl implements Catalog { - private final DatabaseRepository databaseRepository; + private final DatabaseService databaseService; private final ModelService modelService; private final DimensionService dimensionService; private final DatasourceService datasourceService; private final MetricService metricService; - public CatalogImpl(DatabaseRepository databaseRepository, + public CatalogImpl(DatabaseService databaseService, ModelService modelService, DimensionService dimensionService, DatasourceService datasourceService, MetricService metricService) { - this.databaseRepository = databaseRepository; + this.databaseService = databaseService; this.modelService = modelService; this.dimensionService = dimensionService; this.datasourceService = datasourceService; @@ -47,14 +34,11 @@ public class CatalogImpl implements Catalog { } public DatabaseResp getDatabase(Long id) { - DatabaseDO databaseDO = databaseRepository.getDatabase(id); - return DatabaseConverter.convert(databaseDO); + return databaseService.getDatabase(id); } public DatabaseResp getDatabaseByModelId(Long modelId) { - List databaseDOS = databaseRepository.getDatabaseByDomainId(modelId); - Optional databaseDO = databaseDOS.stream().findFirst(); - return databaseDO.map(DatabaseConverter::convert).orElse(null); + return databaseService.getDatabaseByModelId(modelId); } @Override diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DimensionServiceImpl.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DimensionServiceImpl.java index b4792864f..0e8046f23 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DimensionServiceImpl.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DimensionServiceImpl.java @@ -1,27 +1,37 @@ package com.tencent.supersonic.semantic.model.application; +import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSONObject; +import com.alibaba.fastjson.TypeReference; import com.github.pagehelper.PageHelper; import com.github.pagehelper.PageInfo; import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; -import com.tencent.supersonic.common.pojo.enums.SensitiveLevelEnum; +import com.tencent.supersonic.common.pojo.QueryColumn; +import com.tencent.supersonic.common.util.ChatGptHelper; +import com.tencent.supersonic.semantic.api.model.pojo.DatasourceDetail; +import com.tencent.supersonic.semantic.api.model.pojo.DimValueMap; import com.tencent.supersonic.semantic.api.model.request.DimensionReq; import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq; +import com.tencent.supersonic.semantic.api.model.response.DatabaseResp; import com.tencent.supersonic.semantic.api.model.response.DatasourceResp; import com.tencent.supersonic.semantic.api.model.response.DimensionResp; +import com.tencent.supersonic.common.pojo.enums.SensitiveLevelEnum; +import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; +import com.tencent.supersonic.semantic.model.domain.DatabaseService; +import com.tencent.supersonic.semantic.model.domain.dataobject.DimensionDO; +import com.tencent.supersonic.semantic.model.domain.repository.DimensionRepository; +import com.tencent.supersonic.semantic.model.domain.utils.DimensionConverter; import com.tencent.supersonic.semantic.model.domain.DatasourceService; import com.tencent.supersonic.semantic.model.domain.DimensionService; import com.tencent.supersonic.semantic.model.domain.DomainService; -import com.tencent.supersonic.semantic.model.domain.dataobject.DimensionDO; import com.tencent.supersonic.semantic.model.domain.pojo.Dimension; import com.tencent.supersonic.semantic.model.domain.pojo.DimensionFilter; -import com.tencent.supersonic.semantic.model.domain.repository.DimensionRepository; -import com.tencent.supersonic.semantic.model.domain.utils.DimensionConverter; -import java.util.HashMap; -import java.util.List; -import java.util.Map; + +import java.util.*; import java.util.stream.Collectors; +import java.util.stream.Stream; + import lombok.extern.slf4j.Slf4j; import org.springframework.beans.BeanUtils; import org.springframework.stereotype.Service; @@ -39,13 +49,22 @@ public class DimensionServiceImpl implements DimensionService { private DomainService domainService; + private ChatGptHelper chatGptHelper; + + private DatabaseService databaseService; + public DimensionServiceImpl(DimensionRepository dimensionRepository, DomainService domainService, - DatasourceService datasourceService) { + DatasourceService datasourceService, + ChatGptHelper chatGptHelper, + DatabaseService databaseService) { this.domainService = domainService; this.dimensionRepository = dimensionRepository; this.datasourceService = datasourceService; + this.chatGptHelper = chatGptHelper; + this.databaseService = databaseService; + } @Override @@ -238,6 +257,53 @@ public class DimensionServiceImpl implements DimensionService { dimensionRepository.deleteDimension(id); } + @Override + public List mockAlias(DimensionReq dimensionReq, String mockType, User user) { + String mockAlias = chatGptHelper.mockAlias(mockType,dimensionReq.getName(), dimensionReq.getBizName(), "", dimensionReq.getDescription() ,false); + return JSONObject.parseObject(mockAlias, new TypeReference>() {}); + } + + @Override + public List mockDimensionValueAlias(DimensionReq dimensionReq, User user) { + + List datasourceList = datasourceService.getDatasourceList(); + List collect = datasourceList.stream().filter(datasourceResp -> datasourceResp.getId().equals(dimensionReq.getDatasourceId())).collect(Collectors.toList()); + + if (collect.isEmpty()){ + return null; + } + DatasourceResp datasourceResp = collect.get(0); + DatasourceDetail datasourceDetail = datasourceResp.getDatasourceDetail(); + String sqlQuery = datasourceDetail.getSqlQuery(); + + DatabaseResp database = databaseService.getDatabase(datasourceResp.getDatabaseId()); + + String sql = "select ai_talk."+dimensionReq.getBizName()+" from ("+sqlQuery +") as ai_talk group by ai_talk."+dimensionReq.getBizName(); + QueryResultWithSchemaResp queryResultWithSchemaResp = databaseService.executeSql(sql, database); + List> resultList = queryResultWithSchemaResp.getResultList(); + List valueList = new ArrayList<>(); + for (Map stringObjectMap : resultList) { + String value = (String) stringObjectMap.get(dimensionReq.getBizName()); + valueList.add(value); + } + String json = chatGptHelper.mockDimensionValueAlias(JSON.toJSONString(valueList)); + log.info("return llm res is :{}",json); + + JSONObject jsonObject = JSON.parseObject(json); + + List dimValueMapsResp = new ArrayList<>(); + int i = 0; + for (Map stringObjectMap : resultList) { + DimValueMap dimValueMap = new DimValueMap(); + dimValueMap.setTechName((String) stringObjectMap.get(dimensionReq.getBizName())); + dimValueMap.setBizName(jsonObject.getJSONArray("tran").getString(i)); + dimValueMap. setAlias(jsonObject.getJSONObject("alias").getJSONArray((String) stringObjectMap.get(dimensionReq.getBizName())).toJavaList(String.class)); + dimValueMapsResp.add(dimValueMap); + i ++ ; + } + return dimValueMapsResp; + } + private void checkExist(List dimensionReqs) { Long modelId = dimensionReqs.get(0).getModelId(); diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DomainServiceImpl.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DomainServiceImpl.java index 389328aaa..af7ed3d6d 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DomainServiceImpl.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DomainServiceImpl.java @@ -16,16 +16,8 @@ import com.tencent.supersonic.semantic.model.domain.dataobject.DomainDO; import com.tencent.supersonic.semantic.model.domain.pojo.Domain; import com.tencent.supersonic.semantic.model.domain.repository.DomainRepository; import com.tencent.supersonic.semantic.model.domain.utils.DomainConvert; -import java.util.ArrayList; -import java.util.Date; -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Queue; -import java.util.Set; + +import java.util.*; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.assertj.core.util.Sets; @@ -101,7 +93,8 @@ public class DomainServiceImpl implements DomainService { List domainIds = modelResps.stream().map(ModelResp::getDomainId).collect(Collectors.toList()); domainWithAuthAll.addAll(getParentDomain(domainIds)); } - return new ArrayList<>(domainWithAuthAll); + return new ArrayList<>(domainWithAuthAll).stream() + .sorted(Comparator.comparingLong(DomainResp::getId)).collect(Collectors.toList()); } @Override diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/MetricServiceImpl.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/MetricServiceImpl.java index 5075c3f16..50de1b114 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/MetricServiceImpl.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/MetricServiceImpl.java @@ -1,26 +1,29 @@ package com.tencent.supersonic.semantic.model.application; import com.alibaba.fastjson.JSONObject; +import com.alibaba.fastjson.TypeReference; import com.github.pagehelper.PageHelper; import com.github.pagehelper.PageInfo; import com.google.common.collect.Lists; +import com.plexpt.chatgpt.ChatGPT; import com.tencent.supersonic.auth.api.authentication.pojo.User; -import com.tencent.supersonic.common.pojo.enums.SensitiveLevelEnum; +import com.tencent.supersonic.common.util.ChatGptHelper; import com.tencent.supersonic.semantic.api.model.pojo.Measure; import com.tencent.supersonic.semantic.api.model.pojo.MetricTypeParams; import com.tencent.supersonic.semantic.api.model.request.MetricReq; import com.tencent.supersonic.semantic.api.model.request.PageMetricReq; import com.tencent.supersonic.semantic.api.model.response.DomainResp; import com.tencent.supersonic.semantic.api.model.response.MetricResp; +import com.tencent.supersonic.common.pojo.enums.SensitiveLevelEnum; import com.tencent.supersonic.semantic.api.model.response.ModelResp; import com.tencent.supersonic.semantic.model.domain.DomainService; -import com.tencent.supersonic.semantic.model.domain.MetricService; import com.tencent.supersonic.semantic.model.domain.ModelService; import com.tencent.supersonic.semantic.model.domain.dataobject.MetricDO; -import com.tencent.supersonic.semantic.model.domain.pojo.Metric; import com.tencent.supersonic.semantic.model.domain.pojo.MetricFilter; import com.tencent.supersonic.semantic.model.domain.repository.MetricRepository; import com.tencent.supersonic.semantic.model.domain.utils.MetricConverter; +import com.tencent.supersonic.semantic.model.domain.MetricService; +import com.tencent.supersonic.semantic.model.domain.pojo.Metric; import java.util.List; import java.util.Map; import java.util.Objects; @@ -42,12 +45,16 @@ public class MetricServiceImpl implements MetricService { private DomainService domainService; + private ChatGptHelper chatGptHelper; + public MetricServiceImpl(MetricRepository metricRepository, - ModelService modelService, - DomainService domainService) { + ModelService modelService, + DomainService domainService, + ChatGptHelper chatGptHelper) { this.domainService = domainService; this.metricRepository = metricRepository; this.modelService = modelService; + this.chatGptHelper = chatGptHelper; } @Override @@ -74,7 +81,7 @@ public class MetricServiceImpl implements MetricService { log.info("[insert metric] object:{}", JSONObject.toJSONString(metricToInsert)); saveMetricBatch(metricToInsert, user); } - + @Override public List getMetrics(Long modelId) { return convertList(metricRepository.getMetricList(modelId)); @@ -201,6 +208,14 @@ public class MetricServiceImpl implements MetricService { metricRepository.deleteMetric(id); } + @Override + public List mockAlias(MetricReq metricReq,String mockType,User user) { + + String mockAlias = chatGptHelper.mockAlias(mockType,metricReq.getName(), metricReq.getBizName(), "", metricReq.getDescription() ,!"".equals(metricReq.getDataFormatType())); + return JSONObject.parseObject(mockAlias, new TypeReference>() {}); + } + + private void saveMetricBatch(List metrics, User user) { if (CollectionUtils.isEmpty(metrics)) { diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/DimensionService.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/DimensionService.java index ff3928afa..10e6a46ea 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/DimensionService.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/DimensionService.java @@ -2,7 +2,9 @@ package com.tencent.supersonic.semantic.model.domain; import com.github.pagehelper.PageInfo; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.semantic.api.model.pojo.DimValueMap; import com.tencent.supersonic.semantic.api.model.request.DimensionReq; +import com.tencent.supersonic.semantic.api.model.request.MetricReq; import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq; import com.tencent.supersonic.semantic.api.model.response.DimensionResp; import java.util.List; @@ -32,4 +34,8 @@ public interface DimensionService { List getAllHighSensitiveDimension(); void deleteDimension(Long id) throws Exception; + + List mockAlias(DimensionReq dimensionReq, String mockType, User user); + + List mockDimensionValueAlias(DimensionReq dimensionReq, User user); } diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/MetricService.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/MetricService.java index a82d311c1..8f674fa19 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/MetricService.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/MetricService.java @@ -32,4 +32,6 @@ public interface MetricService { List getAllHighSensitiveMetric(); void deleteMetric(Long id) throws Exception; + + List mockAlias(MetricReq metricReq,String mockType,User user); } diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/utils/DimensionConverter.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/utils/DimensionConverter.java index 1280e1df7..59688fc7d 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/utils/DimensionConverter.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/utils/DimensionConverter.java @@ -30,6 +30,8 @@ public class DimensionConverter { dimensionDO.setDefaultValues(JSONObject.toJSONString(dimension.getDefaultValues())); if (!CollectionUtils.isEmpty(dimension.getDimValueMaps())) { dimensionDO.setDimValueMaps(JSONObject.toJSONString(dimension.getDimValueMaps())); + } else { + dimensionDO.setDimValueMaps(JSONObject.toJSONString(new ArrayList<>())); } return dimensionDO; } diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/DimensionController.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/DimensionController.java index 3bc1264d1..a4dc762f4 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/DimensionController.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/DimensionController.java @@ -3,13 +3,17 @@ package com.tencent.supersonic.semantic.model.rest; import com.github.pagehelper.PageInfo; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; +import com.tencent.supersonic.semantic.api.model.pojo.DimValueMap; import com.tencent.supersonic.semantic.api.model.request.DimensionReq; +import com.tencent.supersonic.semantic.api.model.request.MetricReq; import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq; import com.tencent.supersonic.semantic.api.model.response.DimensionResp; import com.tencent.supersonic.semantic.model.domain.DimensionService; import java.util.List; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; + +import com.tencent.supersonic.semantic.model.domain.MetricService; import org.springframework.web.bind.annotation.DeleteMapping; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PathVariable; @@ -27,7 +31,10 @@ public class DimensionController { private DimensionService dimensionService; - public DimensionController(DimensionService dimensionService) { + private MetricService metricService; + + public DimensionController(DimensionService dimensionService,MetricService metricService) { + this.metricService = metricService; this.dimensionService = dimensionService; } @@ -56,6 +63,22 @@ public class DimensionController { return true; } + @PostMapping("/mockDimensionAlias") + public List mockMetricAlias(@RequestBody DimensionReq dimensionReq, + HttpServletRequest request, + HttpServletResponse response){ + User user = UserHolder.findUser(request, response); + return dimensionService.mockAlias(dimensionReq,"dimension",user); + } + + + @PostMapping("/mockDimensionValuesAlias") + public List mockDimensionValuesAlias(@RequestBody DimensionReq dimensionReq, + HttpServletRequest request, + HttpServletResponse response){ + User user = UserHolder.findUser(request, response); + return dimensionService.mockDimensionValueAlias(dimensionReq,user); + } @GetMapping("/getDimensionList/{modelId}") public List getDimension(@PathVariable("modelId") Long modelId) { diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/MetricController.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/MetricController.java index b97b74340..7173ea637 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/MetricController.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/MetricController.java @@ -11,13 +11,9 @@ import com.tencent.supersonic.semantic.model.domain.MetricService; import java.util.List; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import org.springframework.web.bind.annotation.DeleteMapping; -import org.springframework.web.bind.annotation.GetMapping; -import org.springframework.web.bind.annotation.PathVariable; -import org.springframework.web.bind.annotation.PostMapping; -import org.springframework.web.bind.annotation.RequestBody; -import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RestController; + +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.web.bind.annotation.*; @RestController @@ -52,6 +48,14 @@ public class MetricController { } + @PostMapping("/mockMetricAlias") + public List mockMetricAlias(@RequestBody MetricReq metricReq, + HttpServletRequest request, + HttpServletResponse response){ + User user = UserHolder.findUser(request, response); + return metricService.mockAlias(metricReq,"indicator",user); + } + @GetMapping("/getMetricList/{modelId}") public List getMetricList(@PathVariable("modelId") Long modelId) { return metricService.getMetrics(modelId); diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/CalculateAggConverter.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/CalculateAggConverter.java index 72911f6ef..e8739721d 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/CalculateAggConverter.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/CalculateAggConverter.java @@ -416,7 +416,7 @@ public class CalculateAggConverter implements SemanticConverter { } private static String getLimit(QueryStructReq queryStructCmd) { - if (queryStructCmd.getLimit() > 0) { + if (queryStructCmd != null && queryStructCmd.getLimit() > 0) { return " limit " + String.valueOf(queryStructCmd.getLimit()); } return ""; diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/ParserDefaultConverter.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/ParserDefaultConverter.java index 598c4c6e6..e2f44efec 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/ParserDefaultConverter.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/convert/ParserDefaultConverter.java @@ -23,8 +23,7 @@ import org.springframework.util.CollectionUtils; @Slf4j public class ParserDefaultConverter implements SemanticConverter { - @Value("${internal.metric.cnt.suffix:internal_cnt}") - private String internalMetricNameSuffix; + private final CalculateAggConverter calculateCoverterAgg; private final QueryStructUtils queryStructUtils; @@ -69,7 +68,7 @@ public class ParserDefaultConverter implements SemanticConverter { // todo tmp delete // support detail query if (queryStructCmd.getNativeQuery() && CollectionUtils.isEmpty(sqlCommend.getMetrics())) { - String internalMetricName = generateInternalMetricName(catalog, queryStructCmd); + String internalMetricName = queryStructUtils.generateInternalMetricName(queryStructCmd.getModelId(), queryStructCmd.getGroups()); sqlCommend.getMetrics().add(internalMetricName); } @@ -77,21 +76,5 @@ public class ParserDefaultConverter implements SemanticConverter { } - public String generateInternalMetricName(Catalog catalog, QueryStructReq queryStructCmd) { - String internalMetricNamePrefix = ""; - if (CollectionUtils.isEmpty(queryStructCmd.getGroups())) { - log.warn("group is empty!"); - } else { - String group = queryStructCmd.getGroups().get(0).equalsIgnoreCase("sys_imp_date") - ? queryStructCmd.getGroups().get(1) : queryStructCmd.getGroups().get(0); - DimensionResp dimension = catalog.getDimension(group, queryStructCmd.getModelId()); - String datasourceBizName = dimension.getDatasourceBizName(); - if (Strings.isNotEmpty(datasourceBizName)) { - internalMetricNamePrefix = datasourceBizName + UNDERLINE; - } - } - String internalMetricName = internalMetricNamePrefix + internalMetricNameSuffix; - return internalMetricName; - } } diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/rest/QueryController.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/rest/QueryController.java index 68eef591b..faf5adb5d 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/rest/QueryController.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/rest/QueryController.java @@ -50,7 +50,7 @@ public class QueryController { HttpServletRequest request, HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); - return queryService.queryByStruct(queryStructReq, user, request); + return queryService.queryByStructWithAuth(queryStructReq, user); } @PostMapping("/struct/parse") diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryService.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryService.java index 5a79395c5..78cb66577 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryService.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryService.java @@ -1,7 +1,6 @@ package com.tencent.supersonic.semantic.query.service; import com.tencent.supersonic.auth.api.authentication.pojo.User; -import com.tencent.supersonic.semantic.api.model.pojo.QueryStat; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; import com.tencent.supersonic.semantic.api.query.request.ItemUseReq; import com.tencent.supersonic.semantic.api.query.request.QueryDslReq; @@ -9,22 +8,18 @@ import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import com.tencent.supersonic.semantic.api.query.response.ItemUseResp; import java.util.List; -import javax.servlet.http.HttpServletRequest; public interface QueryService { - Object queryBySql(QueryDslReq querySqlCmd, User user) throws Exception; QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructCmd, User user) throws Exception; - QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructCmd, User user, HttpServletRequest request) + QueryResultWithSchemaResp queryByStructWithAuth(QueryStructReq queryStructCmd, User user) throws Exception; QueryResultWithSchemaResp queryByMultiStruct(QueryMultiStructReq queryMultiStructCmd, User user) throws Exception; List getStatInfo(ItemUseReq itemUseCommend); - List getQueryStatInfoWithoutCache(ItemUseReq itemUseCommend); - } diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java index 6bbb5aec9..414fb7d6d 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java @@ -22,7 +22,7 @@ import com.tencent.supersonic.semantic.query.utils.StatUtils; import java.util.ArrayList; import java.util.List; import java.util.Objects; -import javax.servlet.http.HttpServletRequest; +import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.springframework.beans.factory.annotation.Value; @@ -110,8 +110,8 @@ public class QueryServiceImpl implements QueryService { @Override @DataPermission - public QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructCmd, User user, HttpServletRequest request) - throws Exception { + @SneakyThrows + public QueryResultWithSchemaResp queryByStructWithAuth(QueryStructReq queryStructCmd, User user) { return queryByStruct(queryStructCmd, user); } @@ -171,12 +171,6 @@ public class QueryServiceImpl implements QueryService { return statInfos; } - - @Override - public List getQueryStatInfoWithoutCache(ItemUseReq itemUseCommend) { - return statUtils.getQueryStatInfoWithoutCache(itemUseCommend); - } - private boolean isCache(QueryStructReq queryStructCmd) { if (!cacheEnable) { return false; diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DataPermissionAOP.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DataPermissionAOP.java index 9fdd3ce46..d7f5cfd49 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DataPermissionAOP.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DataPermissionAOP.java @@ -90,11 +90,15 @@ public class DataPermissionAOP { if (Objects.isNull(user) || Strings.isNullOrEmpty(user.getName())) { throw new RuntimeException("lease provide user information"); } + //1. determine whether admin of the model + if (doModelAdmin(user, queryStructReq)) { + return point.proceed(); + } - // 1. determine whether the subject field is visible - doDomainVisible(user, queryStructReq); + // 2. determine whether the subject field is visible + doModelVisible(user, queryStructReq); - // 2. fetch data permission meta information + // 3. fetch data permission meta information Long modelId = queryStructReq.getModelId(); Set res4Privilege = queryStructUtils.getResNameEnExceptInternalCol(queryStructReq); log.info("modelId:{}, res4Privilege:{}", modelId, res4Privilege); @@ -105,18 +109,17 @@ public class DataPermissionAOP { log.info("this query domainId:{}, sensitiveResReq:{}", modelId, sensitiveResReq); // query user privilege info - HttpServletRequest request = (HttpServletRequest) args[2]; - AuthorizedResourceResp authorizedResource = getAuthorizedResource(user, request, modelId, sensitiveResReq); + AuthorizedResourceResp authorizedResource = getAuthorizedResource(user, modelId, sensitiveResReq); // get sensitiveRes that user has privilege Set resAuthSet = getAuthResNameSet(authorizedResource, queryStructReq.getModelId()); - // 3.if sensitive fields without permission are involved in filter, thrown an exception + // 4.if sensitive fields without permission are involved in filter, thrown an exception doFilterCheckLogic(queryStructReq, resAuthSet, sensitiveResReq); - // 4.row permission pre-filter + // 5.row permission pre-filter doRowPermission(queryStructReq, authorizedResource); - // 5.proceed + // 6.proceed QueryResultWithSchemaResp queryResultWithColumns = (QueryResultWithSchemaResp) point.proceed(); if (CollectionUtils.isEmpty(sensitiveResReq) || allSensitiveResReqIsOk(sensitiveResReq, resAuthSet)) { @@ -136,7 +139,19 @@ public class DataPermissionAOP { } - private void doDomainVisible(User user, QueryStructReq queryStructCmd) { + private boolean doModelAdmin(User user, QueryStructReq queryStructCmd) { + Long modelId = queryStructCmd.getModelId(); + List modelListAdmin = modelService.getModelListWithAuth(user.getName(), null, AuthType.ADMIN); + if (CollectionUtils.isEmpty(modelListAdmin)) { + return false; + } else { + Map> id2modelResp = modelListAdmin.stream() + .collect(Collectors.groupingBy(SchemaItem::getId)); + return !CollectionUtils.isEmpty(id2modelResp) && id2modelResp.containsKey(modelId); + } + } + + private void doModelVisible(User user, QueryStructReq queryStructCmd) { Boolean visible = true; Long domainId = queryStructCmd.getModelId(); List modelListVisible = modelService.getModelListWithAuth(user.getName(), null, AuthType.VISIBLE); @@ -251,15 +266,14 @@ public class DataPermissionAOP { return resAuthName; } - private AuthorizedResourceResp getAuthorizedResource(User user, HttpServletRequest request, Long domainId, + private AuthorizedResourceResp getAuthorizedResource(User user, Long domainId, Set sensitiveResReq) { List resourceReqList = new ArrayList<>(); - sensitiveResReq.stream().forEach(res -> resourceReqList.add(new AuthRes(domainId.toString(), res))); + sensitiveResReq.forEach(res -> resourceReqList.add(new AuthRes(domainId.toString(), res))); QueryAuthResReq queryAuthResReq = new QueryAuthResReq(); - queryAuthResReq.setUser(user.getName()); queryAuthResReq.setResources(resourceReqList); queryAuthResReq.setModelId(domainId + ""); - AuthorizedResourceResp authorizedResource = fetchAuthRes(request, queryAuthResReq); + AuthorizedResourceResp authorizedResource = fetchAuthRes(queryAuthResReq, user); log.info("user:{}, domainId:{}, after queryAuthorizedResources:{}", user.getName(), domainId, authorizedResource); return authorizedResource; @@ -396,10 +410,9 @@ public class DataPermissionAOP { } - private AuthorizedResourceResp fetchAuthRes(HttpServletRequest request, QueryAuthResReq queryAuthResReq) { - log.info("Authorization:{}", request.getHeader("Authorization")); + private AuthorizedResourceResp fetchAuthRes(QueryAuthResReq queryAuthResReq, User user) { log.info("queryAuthResReq:{}", queryAuthResReq); - return authService.queryAuthorizedResources(queryAuthResReq, request); + return authService.queryAuthorizedResources(queryAuthResReq, user); } } \ No newline at end of file diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DimValueAspect.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DimValueAspect.java index 6bdb2e08c..a37c2a704 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DimValueAspect.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DimValueAspect.java @@ -35,7 +35,8 @@ public class DimValueAspect { private DimensionService dimensionService; @Around("execution(* com.tencent.supersonic.semantic.query.rest.QueryController.queryByStruct(..))" + - " || execution(* com.tencent.supersonic.semantic.query.service.QueryService.queryByStruct(..))") + " || execution(* com.tencent.supersonic.semantic.query.service.QueryService.queryByStruct(..))" + + " || execution(* com.tencent.supersonic.semantic.query.service.QueryService.queryByStructWithAuth(..))") public Object handleDimValue(ProceedingJoinPoint joinPoint) throws Throwable { if (!dimensionValueMapEnable) { diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryReqConverter.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryReqConverter.java index 2c0dda7ac..6bcb6738c 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryReqConverter.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryReqConverter.java @@ -10,6 +10,7 @@ import com.tencent.supersonic.semantic.model.domain.ModelService; import com.tencent.supersonic.semantic.query.persistence.pojo.QueryStatement; import com.tencent.supersonic.semantic.query.service.SemanticQueryEngine; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Set; import java.util.stream.Collectors; @@ -28,6 +29,9 @@ public class QueryReqConverter { @Autowired private SemanticQueryEngine parserService; + @Autowired + private QueryStructUtils queryStructUtils; + public QueryStatement convert(QueryDslReq databaseReq, List domainSchemas) throws Exception { List tables = new ArrayList<>(); @@ -60,6 +64,12 @@ public class QueryReqConverter { } metricTable.setDimensions(new ArrayList<>(collect)); metricTable.setAlias(tableName.toLowerCase()); + // if metric empty , fill model default + if (CollectionUtils.isEmpty(metricTable.getMetrics())) { + metricTable.setMetrics(new ArrayList<>(Arrays.asList( + queryStructUtils.generateInternalMetricName(databaseReq.getModelId(), + metricTable.getDimensions())))); + } tables.add(metricTable); ParseSqlReq result = new ParseSqlReq(); diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryStructUtils.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryStructUtils.java index b3efc6d44..4c1ede9f8 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryStructUtils.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryStructUtils.java @@ -1,16 +1,19 @@ package com.tencent.supersonic.semantic.query.utils; -import com.tencent.supersonic.common.pojo.Aggregator; -import com.tencent.supersonic.common.pojo.DateConf; +import static com.tencent.supersonic.common.pojo.Constants.UNDERLINE; + import com.tencent.supersonic.common.pojo.DateConf.DateMode; import com.tencent.supersonic.common.pojo.enums.TypeEnums; -import com.tencent.supersonic.semantic.api.model.pojo.ItemDateFilter; +import com.tencent.supersonic.common.pojo.Aggregator; +import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem; +import com.tencent.supersonic.semantic.api.model.pojo.ItemDateFilter; import com.tencent.supersonic.semantic.api.model.response.DimensionResp; import com.tencent.supersonic.semantic.api.model.response.ItemDateResp; import com.tencent.supersonic.semantic.api.model.response.MetricResp; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import com.tencent.supersonic.semantic.model.domain.Catalog; + import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; @@ -19,9 +22,12 @@ import java.util.Map; import java.util.Objects; import java.util.Set; import java.util.stream.Collectors; + import lombok.extern.slf4j.Slf4j; import org.apache.logging.log4j.util.Strings; +import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; +import org.springframework.util.CollectionUtils; @Slf4j @@ -32,6 +38,8 @@ public class QueryStructUtils { private final SqlFilterUtils sqlFilterUtils; private final Catalog catalog; + @Value("${internal.metric.cnt.suffix:internal_cnt}") + private String internalMetricNameSuffix; public static Set internalCols = new HashSet<>( Arrays.asList("dayno", "plat_sys_var", "sys_imp_date", "sys_imp_week", "sys_imp_month")); @@ -157,5 +165,23 @@ public class QueryStructUtils { return resNameEnSet.stream().filter(res -> !internalCols.contains(res)).collect(Collectors.toSet()); } + public String generateInternalMetricName(Long modelId, List groups) { + String internalMetricNamePrefix = ""; + if (CollectionUtils.isEmpty(groups)) { + log.warn("group is empty!"); + } else { + String group = groups.get(0).equalsIgnoreCase("sys_imp_date") + ? groups.get(1) : groups.get(0); + DimensionResp dimension = catalog.getDimension(group, modelId); + String datasourceBizName = dimension.getDatasourceBizName(); + if (Strings.isNotEmpty(datasourceBizName)) { + internalMetricNamePrefix = datasourceBizName + UNDERLINE; + } + + } + String internalMetricName = internalMetricNamePrefix + internalMetricNameSuffix; + return internalMetricName; + } + } diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryUtils.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryUtils.java index 18b9e3fee..2fd4d5bab 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryUtils.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryUtils.java @@ -3,11 +3,11 @@ package com.tencent.supersonic.semantic.query.utils; import static com.tencent.supersonic.common.pojo.Constants.JOIN_UNDERLINE; import static com.tencent.supersonic.common.pojo.Constants.UNIONALL; -import com.tencent.supersonic.common.pojo.Aggregator; import com.tencent.supersonic.common.pojo.Constants; -import com.tencent.supersonic.common.pojo.QueryColumn; +import com.tencent.supersonic.common.pojo.Aggregator; import com.tencent.supersonic.common.util.cache.CacheUtils; import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; +import com.tencent.supersonic.common.pojo.QueryColumn; import com.tencent.supersonic.semantic.api.model.response.DimensionResp; import com.tencent.supersonic.semantic.api.model.response.MetricResp; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; @@ -67,8 +67,8 @@ public class QueryUtils { public void fillItemNameInfo(QueryResultWithSchemaResp queryResultWithColumns, Long modelId) { List metricDescList = catalog.getMetrics(modelId); List dimensionDescList = catalog.getDimensions(modelId); - Map metricRespMap = - metricDescList.stream().collect(Collectors.toMap(MetricResp::getBizName, a -> a, (k1, k2) -> k1)); + Map metricRespMap = + metricDescList.stream().collect(Collectors.toMap(MetricResp::getBizName, a -> a,(k1, k2)->k1)); Map namePair = new HashMap<>(); Map nameTypePair = new HashMap<>(); addSysTimeDimension(namePair, nameTypePair); @@ -92,13 +92,27 @@ public class QueryUtils { if (nameTypePair.containsKey(nameEn)) { column.setShowType(nameTypePair.get(nameEn)); } - if (metricRespMap.containsKey(nameEn)) { + if (!nameTypePair.containsKey(nameEn) && isNumberType(column.getType())) { + column.setShowType("NUMBER"); + } + if(metricRespMap.containsKey(nameEn)){ column.setDataFormatType(metricRespMap.get(nameEn).getDataFormatType()); column.setDataFormat(metricRespMap.get(nameEn).getDataFormat()); } }); } + private boolean isNumberType(String type) { + if (StringUtils.isBlank(type)) { + return false; + } + if (type.equalsIgnoreCase("int") || type.equalsIgnoreCase("bigint") + || type.equalsIgnoreCase("float") || type.equalsIgnoreCase("double")) { + return true; + } + return false; + } + public void fillItemNameInfo(QueryResultWithSchemaResp queryResultWithColumns, QueryMultiStructReq queryMultiStructCmd) { List aggregators = queryMultiStructCmd.getQueryStructReqs().stream()