diff --git a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/annotation/AuthenticationIgnore.java b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/annotation/AuthenticationIgnore.java index 6752247eb..095614ffe 100644 --- a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/annotation/AuthenticationIgnore.java +++ b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/annotation/AuthenticationIgnore.java @@ -7,4 +7,5 @@ import java.lang.annotation.Target; @Target({ElementType.METHOD}) @Retention(RetentionPolicy.RUNTIME) -public @interface AuthenticationIgnore {} +public @interface AuthenticationIgnore { +} diff --git a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/config/AuthenticationConfig.java b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/config/AuthenticationConfig.java index 207532489..c014524c6 100644 --- a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/config/AuthenticationConfig.java +++ b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/config/AuthenticationConfig.java @@ -24,9 +24,8 @@ public class AuthenticationConfig { @Value("${s2.authentication.token.default.appKey:supersonic}") private String tokenDefaultAppKey; - @Value( - "${s2.authentication.token.appSecret:supersonic:WIaO9YRRVt+7QtpPvyWsARFngnEcbaKBk" - + "783uGFwMrbJBaochsqCH62L4Kijcb0sZCYoSsiKGV/zPml5MnZ3uQ==}") + @Value("${s2.authentication.token.appSecret:supersonic:WIaO9YRRVt+7QtpPvyWsARFngnEcbaKBk" + + "783uGFwMrbJBaochsqCH62L4Kijcb0sZCYoSsiKGV/zPml5MnZ3uQ==}") private String tokenAppSecret; @Value("${s2.authentication.token.http.header.key:Authorization}") @@ -48,8 +47,7 @@ public class AuthenticationConfig { private Long tokenTimeout; public Map getAppKeyToSecretMap() { - return Arrays.stream(this.tokenAppSecret.split(",")) - .map(s -> s.split(":")) + return Arrays.stream(this.tokenAppSecret.split(",")).map(s -> s.split(":")) .collect(Collectors.toMap(e -> e[0].trim(), e -> e[1].trim())); } } diff --git a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/User.java b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/User.java index 577bbb0fe..adaf94814 100644 --- a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/User.java +++ b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/User.java @@ -20,8 +20,8 @@ public class User { private Integer isAdmin; - public static User get( - Long id, String name, String displayName, String email, Integer isAdmin) { + public static User get(Long id, String name, String displayName, String email, + Integer isAdmin) { return new User(id, name, displayName, email, isAdmin); } diff --git a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/UserWithPassword.java b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/UserWithPassword.java index a8b161028..e02931623 100644 --- a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/UserWithPassword.java +++ b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/UserWithPassword.java @@ -9,24 +9,14 @@ public class UserWithPassword extends User { private String password; - public UserWithPassword( - Long id, - String name, - String displayName, - String email, - String password, + public UserWithPassword(Long id, String name, String displayName, String email, String password, Integer isAdmin) { super(id, name, displayName, email, isAdmin); this.password = password; } - public static UserWithPassword get( - Long id, - String name, - String displayName, - String email, - String password, - Integer isAdmin) { + public static UserWithPassword get(Long id, String name, String displayName, String email, + String password, Integer isAdmin) { return new UserWithPassword(id, name, displayName, email, password, isAdmin); } } diff --git a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/service/UserService.java b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/service/UserService.java index 393e0531d..0d9296428 100644 --- a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/service/UserService.java +++ b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/service/UserService.java @@ -12,8 +12,8 @@ import java.util.Set; public interface UserService { - User getCurrentUser( - HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse); + User getCurrentUser(HttpServletRequest httpServletRequest, + HttpServletResponse httpServletResponse); List getUserNames(); diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/adaptor/DefaultUserAdaptor.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/adaptor/DefaultUserAdaptor.java index 41d6f60d3..f43bc17ce 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/adaptor/DefaultUserAdaptor.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/adaptor/DefaultUserAdaptor.java @@ -52,12 +52,10 @@ public class DefaultUserAdaptor implements UserAdaptor { new Organization("1", "0", "SuperSonic", "SuperSonic", Lists.newArrayList(), true); Organization hr = new Organization("2", "1", "Hr", "SuperSonic/Hr", Lists.newArrayList(), false); - Organization sales = - new Organization( - "3", "1", "Sales", "SuperSonic/Sales", Lists.newArrayList(), false); - Organization marketing = - new Organization( - "4", "1", "Marketing", "SuperSonic/Marketing", Lists.newArrayList(), false); + Organization sales = new Organization("3", "1", "Sales", "SuperSonic/Sales", + Lists.newArrayList(), false); + Organization marketing = new Organization("4", "1", "Marketing", "SuperSonic/Marketing", + Lists.newArrayList(), false); List subOrganization = Lists.newArrayList(hr, sales, marketing); superSonic.setSubOrganizations(subOrganization); return Lists.newArrayList(superSonic); @@ -113,19 +111,12 @@ public class DefaultUserAdaptor implements UserAdaptor { throw new RuntimeException("user not exist,please register"); } try { - String password = - AESEncryptionUtil.encrypt( - userReq.getPassword(), - AESEncryptionUtil.getBytesFromString(userDO.getSalt())); + String password = AESEncryptionUtil.encrypt(userReq.getPassword(), + AESEncryptionUtil.getBytesFromString(userDO.getSalt())); if (userDO.getPassword().equals(password)) { - UserWithPassword user = - UserWithPassword.get( - userDO.getId(), - userDO.getName(), - userDO.getDisplayName(), - userDO.getEmail(), - userDO.getPassword(), - userDO.getIsAdmin()); + UserWithPassword user = UserWithPassword.get(userDO.getId(), userDO.getName(), + userDO.getDisplayName(), userDO.getEmail(), userDO.getPassword(), + userDO.getIsAdmin()); return user; } else { throw new RuntimeException("password not correct, please try again"); diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/interceptor/AuthenticationInterceptor.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/interceptor/AuthenticationInterceptor.java index 0ba60b894..e931d853d 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/interceptor/AuthenticationInterceptor.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/interceptor/AuthenticationInterceptor.java @@ -68,8 +68,8 @@ public abstract class AuthenticationInterceptor implements HandlerInterceptor { try { if (request instanceof StandardMultipartHttpServletRequest) { RequestFacade servletRequest = - (RequestFacade) - ((StandardMultipartHttpServletRequest) request).getRequest(); + (RequestFacade) ((StandardMultipartHttpServletRequest) request) + .getRequest(); Class servletRequestClazz = servletRequest.getClass(); Field request1 = servletRequestClazz.getDeclaredField("request"); request1.setAccessible(true); diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/interceptor/DefaultAuthenticationInterceptor.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/interceptor/DefaultAuthenticationInterceptor.java index b0bdbcd1d..3aa00bb79 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/interceptor/DefaultAuthenticationInterceptor.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/interceptor/DefaultAuthenticationInterceptor.java @@ -22,9 +22,8 @@ import java.lang.reflect.Method; public class DefaultAuthenticationInterceptor extends AuthenticationInterceptor { @Override - public boolean preHandle( - HttpServletRequest request, HttpServletResponse response, Object handler) - throws AccessException { + public boolean preHandle(HttpServletRequest request, HttpServletResponse response, + Object handler) throws AccessException { authenticationConfig = ContextUtils.getBean(AuthenticationConfig.class); userServiceImpl = ContextUtils.getBean(UserServiceImpl.class); userTokenUtils = ContextUtils.getBean(UserTokenUtils.class); @@ -74,11 +73,9 @@ public class DefaultAuthenticationInterceptor extends AuthenticationInterceptor } private void setContext(String userName, HttpServletRequest request) { - ThreadContext threadContext = - ThreadContext.builder() - .token(request.getHeader(authenticationConfig.getTokenHttpHeaderKey())) - .userName(userName) - .build(); + ThreadContext threadContext = ThreadContext.builder() + .token(request.getHeader(authenticationConfig.getTokenHttpHeaderKey())) + .userName(userName).build(); s2ThreadContext.set(threadContext); } } diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/interceptor/InterceptorFactory.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/interceptor/InterceptorFactory.java index b688c2fbb..e1b1620ee 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/interceptor/InterceptorFactory.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/interceptor/InterceptorFactory.java @@ -13,17 +13,14 @@ public class InterceptorFactory implements WebMvcConfigurer { private List authenticationInterceptors; public InterceptorFactory() { - authenticationInterceptors = - SpringFactoriesLoader.loadFactories( - AuthenticationInterceptor.class, - Thread.currentThread().getContextClassLoader()); + authenticationInterceptors = SpringFactoriesLoader.loadFactories( + AuthenticationInterceptor.class, Thread.currentThread().getContextClassLoader()); } @Override public void addInterceptors(InterceptorRegistry registry) { for (AuthenticationInterceptor authenticationInterceptor : authenticationInterceptors) { - registry.addInterceptor(authenticationInterceptor) - .addPathPatterns("/**") + registry.addInterceptor(authenticationInterceptor).addPathPatterns("/**") .excludePathPatterns("/", "/webapp/**", "/error"); } } diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/persistence/dataobject/UserDOExample.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/persistence/dataobject/UserDOExample.java index 7c288784e..7e869521d 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/persistence/dataobject/UserDOExample.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/persistence/dataobject/UserDOExample.java @@ -138,8 +138,8 @@ public class UserDOExample { criteria.add(new Criterion(condition, value)); } - protected void addCriterion( - String condition, Object value1, Object value2, String property) { + protected void addCriterion(String condition, Object value1, Object value2, + String property) { if (value1 == null || value2 == null) { throw new RuntimeException("Between values for " + property + " cannot be null"); } @@ -628,8 +628,8 @@ public class UserDOExample { this(condition, value, null); } - protected Criterion( - String condition, Object value, Object secondValue, String typeHandler) { + protected Criterion(String condition, Object value, Object secondValue, + String typeHandler) { super(); this.condition = condition; this.value = value; diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/rest/UserController.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/rest/UserController.java index d37abddee..a5de7bbc4 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/rest/UserController.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/rest/UserController.java @@ -30,8 +30,8 @@ public class UserController { } @GetMapping("/getCurrentUser") - public User getCurrentUser( - HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) { + public User getCurrentUser(HttpServletRequest httpServletRequest, + HttpServletResponse httpServletResponse) { return userService.getCurrentUser(httpServletRequest, httpServletResponse); } diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/service/UserServiceImpl.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/service/UserServiceImpl.java index 5f321d1c8..11f74bba6 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/service/UserServiceImpl.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/service/UserServiceImpl.java @@ -27,8 +27,8 @@ public class UserServiceImpl implements UserService { } @Override - public User getCurrentUser( - HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) { + public User getCurrentUser(HttpServletRequest httpServletRequest, + HttpServletResponse httpServletResponse) { User user = UserHolder.findUser(httpServletRequest, httpServletResponse); if (user != null) { SystemConfig systemConfig = sysParameterService.getSystemConfig(); diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/strategy/UserStrategyFactory.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/strategy/UserStrategyFactory.java index d1f9758ba..b7becf8cd 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/strategy/UserStrategyFactory.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/strategy/UserStrategyFactory.java @@ -18,8 +18,8 @@ public class UserStrategyFactory { private AuthenticationConfig authenticationConfig; - public UserStrategyFactory( - AuthenticationConfig authenticationConfig, List userStrategyList) { + public UserStrategyFactory(AuthenticationConfig authenticationConfig, + List userStrategyList) { this.authenticationConfig = authenticationConfig; this.userStrategyList = userStrategyList; } diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/ComponentFactory.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/ComponentFactory.java index 8176a4604..1afd8b3d1 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/ComponentFactory.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/ComponentFactory.java @@ -17,8 +17,7 @@ public class ComponentFactory { } private static T init(Class factoryType) { - return SpringFactoriesLoader.loadFactories( - factoryType, Thread.currentThread().getContextClassLoader()) - .get(0); + return SpringFactoriesLoader + .loadFactories(factoryType, Thread.currentThread().getContextClassLoader()).get(0); } } diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/UserTokenUtils.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/UserTokenUtils.java index 2bd5ed451..af7b80788 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/UserTokenUtils.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/UserTokenUtils.java @@ -48,8 +48,7 @@ public class UserTokenUtils { Map claims = new HashMap<>(5); claims.put(TOKEN_USER_ID, user.getId()); claims.put(TOKEN_USER_NAME, StringUtils.isEmpty(user.getName()) ? "" : user.getName()); - claims.put( - TOKEN_USER_PASSWORD, + claims.put(TOKEN_USER_PASSWORD, StringUtils.isEmpty(user.getPassword()) ? "" : user.getPassword()); claims.put(TOKEN_USER_DISPLAY_NAME, user.getDisplayName()); claims.put(TOKEN_CREATE_TIME, System.currentTimeMillis()); @@ -83,10 +82,8 @@ public class UserTokenUtils { String userName = String.valueOf(claims.get(TOKEN_USER_NAME)); String email = String.valueOf(claims.get(TOKEN_USER_EMAIL)); String displayName = String.valueOf(claims.get(TOKEN_USER_DISPLAY_NAME)); - Integer isAdmin = - claims.get(TOKEN_IS_ADMIN) == null - ? 0 - : Integer.parseInt(claims.get(TOKEN_IS_ADMIN).toString()); + Integer isAdmin = claims.get(TOKEN_IS_ADMIN) == null ? 0 + : Integer.parseInt(claims.get(TOKEN_IS_ADMIN).toString()); return User.get(userId, userName, displayName, email, isAdmin); } @@ -105,10 +102,8 @@ public class UserTokenUtils { String email = String.valueOf(claims.get(TOKEN_USER_EMAIL)); String displayName = String.valueOf(claims.get(TOKEN_USER_DISPLAY_NAME)); String password = String.valueOf(claims.get(TOKEN_USER_PASSWORD)); - Integer isAdmin = - claims.get(TOKEN_IS_ADMIN) == null - ? 0 - : Integer.parseInt(claims.get(TOKEN_IS_ADMIN).toString()); + Integer isAdmin = claims.get(TOKEN_IS_ADMIN) == null ? 0 + : Integer.parseInt(claims.get(TOKEN_IS_ADMIN).toString()); return UserWithPassword.get(userId, userName, displayName, email, password, isAdmin); } @@ -121,11 +116,8 @@ public class UserTokenUtils { try { String tokenSecret = getTokenSecret(appKey); Claims claims = - Jwts.parser() - .setSigningKey(tokenSecret.getBytes(StandardCharsets.UTF_8)) - .build() - .parseClaimsJws(getTokenString(token)) - .getBody(); + Jwts.parser().setSigningKey(tokenSecret.getBytes(StandardCharsets.UTF_8)) + .build().parseClaimsJws(getTokenString(token)).getBody(); return Optional.of(claims); } catch (Exception e) { log.info("can not getClaims from appKey:{} token:{}, please login", appKey, token); @@ -149,15 +141,10 @@ public class UserTokenUtils { Date expirationDate = new Date(expiration); String tokenSecret = getTokenSecret(appKey); - return Jwts.builder() - .setClaims(claims) - .setSubject(claims.get(TOKEN_USER_NAME).toString()) + return Jwts.builder().setClaims(claims).setSubject(claims.get(TOKEN_USER_NAME).toString()) .setExpiration(expirationDate) - .signWith( - new SecretKeySpec( - tokenSecret.getBytes(StandardCharsets.UTF_8), - SignatureAlgorithm.HS512.getJcaName()), - SignatureAlgorithm.HS512) + .signWith(new SecretKeySpec(tokenSecret.getBytes(StandardCharsets.UTF_8), + SignatureAlgorithm.HS512.getJcaName()), SignatureAlgorithm.HS512) .compact(); } diff --git a/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/rest/AuthController.java b/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/rest/AuthController.java index f4e52ff9a..2fd7bca0e 100644 --- a/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/rest/AuthController.java +++ b/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/rest/AuthController.java @@ -31,8 +31,7 @@ public class AuthController { } @GetMapping("/queryGroup") - public List queryAuthGroup( - @RequestParam("modelId") String modelId, + public List queryAuthGroup(@RequestParam("modelId") String modelId, @RequestParam(value = "groupId", required = false) Integer groupId) { return authService.queryAuthGroups(modelId, groupId); } @@ -69,10 +68,8 @@ public class AuthController { * @return */ @PostMapping("/queryAuthorizedRes") - public AuthorizedResourceResp queryAuthorizedResources( - @RequestBody QueryAuthResReq req, - HttpServletRequest request, - HttpServletResponse response) { + public AuthorizedResourceResp queryAuthorizedResources(@RequestBody QueryAuthResReq req, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return authService.queryAuthorizedResources(req, user); } diff --git a/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/service/AuthServiceImpl.java b/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/service/AuthServiceImpl.java index d6b899e8e..5bf756e53 100644 --- a/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/service/AuthServiceImpl.java +++ b/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/service/AuthServiceImpl.java @@ -39,18 +39,15 @@ public class AuthServiceImpl implements AuthService { List rows = jdbcTemplate.queryForList("select config from s2_auth_groups", String.class); Gson g = new Gson(); - return rows.stream() - .map(row -> g.fromJson(row, AuthGroup.class)) + return rows.stream().map(row -> g.fromJson(row, AuthGroup.class)) .collect(Collectors.toList()); } @Override public List queryAuthGroups(String modelId, Integer groupId) { return load().stream() - .filter( - group -> - (Objects.isNull(groupId) || groupId.equals(group.getGroupId())) - && modelId.equals(group.getModelId().toString())) + .filter(group -> (Objects.isNull(groupId) || groupId.equals(group.getGroupId())) + && modelId.equals(group.getModelId().toString())) .collect(Collectors.toList()); } @@ -65,15 +62,11 @@ public class AuthServiceImpl implements AuthService { nextGroupId = obj + 1; } group.setGroupId(nextGroupId); - jdbcTemplate.update( - "insert into s2_auth_groups (group_id, config) values (?, ?);", - nextGroupId, - g.toJson(group)); + jdbcTemplate.update("insert into s2_auth_groups (group_id, config) values (?, ?);", + nextGroupId, g.toJson(group)); } else { - jdbcTemplate.update( - "update s2_auth_groups set config = ? where group_id = ?;", - g.toJson(group), - group.getGroupId()); + jdbcTemplate.update("update s2_auth_groups set config = ? where group_id = ?;", + g.toJson(group), group.getGroupId()); } } @@ -119,30 +112,24 @@ public class AuthServiceImpl implements AuthService { return resource; } - private List getAuthGroups( - List modelIds, String userName, List departmentIds) { - List groups = - load().stream() - .filter( - group -> { - if (!modelIds.contains(group.getModelId())) { - return false; - } - if (!CollectionUtils.isEmpty(group.getAuthorizedUsers()) - && group.getAuthorizedUsers().contains(userName)) { - return true; - } - for (String departmentId : departmentIds) { - if (!CollectionUtils.isEmpty( - group.getAuthorizedDepartmentIds()) - && group.getAuthorizedDepartmentIds() - .contains(departmentId)) { - return true; - } - } - return false; - }) - .collect(Collectors.toList()); + private List getAuthGroups(List modelIds, String userName, + List departmentIds) { + List groups = load().stream().filter(group -> { + if (!modelIds.contains(group.getModelId())) { + return false; + } + if (!CollectionUtils.isEmpty(group.getAuthorizedUsers()) + && group.getAuthorizedUsers().contains(userName)) { + return true; + } + for (String departmentId : departmentIds) { + if (!CollectionUtils.isEmpty(group.getAuthorizedDepartmentIds()) + && group.getAuthorizedDepartmentIds().contains(departmentId)) { + return true; + } + } + return false; + }).collect(Collectors.toList()); log.info("user:{} department:{} authGroups:{}", userName, departmentIds, groups); return groups; } diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/enums/MemoryReviewResult.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/enums/MemoryReviewResult.java index ecea94ef0..03e04fd89 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/enums/MemoryReviewResult.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/enums/MemoryReviewResult.java @@ -4,8 +4,7 @@ import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException; import org.apache.commons.lang3.StringUtils; public enum MemoryReviewResult { - POSITIVE, - NEGATIVE; + POSITIVE, NEGATIVE; public static MemoryReviewResult getMemoryReviewResult(String value) { String validValue = StringUtils.trim(value); diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/enums/MemoryStatus.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/enums/MemoryStatus.java index 6850a3b1d..6474be963 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/enums/MemoryStatus.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/enums/MemoryStatus.java @@ -1,7 +1,5 @@ package com.tencent.supersonic.chat.api.pojo.enums; public enum MemoryStatus { - PENDING, - ENABLED, - DISABLED; + PENDING, ENABLED, DISABLED; } diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/KnowledgeInfoReq.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/KnowledgeInfoReq.java index 739e2a042..f15e6db3f 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/KnowledgeInfoReq.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/KnowledgeInfoReq.java @@ -14,7 +14,8 @@ public class KnowledgeInfoReq { private String bizName; /** type: IntentionTypeEnum temporarily only supports dimension-related information */ - @NotNull private TypeEnums type = TypeEnums.DIMENSION; + @NotNull + private TypeEnums type = TypeEnums.DIMENSION; private Boolean searchEnable = false; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java index 87ffae79f..8605d7db5 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/Agent.java @@ -43,16 +43,12 @@ public class Agent extends RecordInfo { return Lists.newArrayList(); } List toolList = (List) map.get("tools"); - return toolList.stream() - .filter( - tool -> { - if (Objects.isNull(type)) { - return true; - } - return type.name().equals(tool.get("type")); - }) - .map(JSONObject::toJSONString) - .collect(Collectors.toList()); + return toolList.stream().filter(tool -> { + if (Objects.isNull(type)) { + return true; + } + return type.name().equals(tool.get("type")); + }).map(JSONObject::toJSONString).collect(Collectors.toList()); } public boolean enableSearch() { @@ -72,8 +68,7 @@ public class Agent extends RecordInfo { if (CollectionUtils.isEmpty(tools)) { return Lists.newArrayList(); } - return tools.stream() - .map(tool -> JSONObject.parseObject(tool, NL2SQLTool.class)) + return tools.stream().map(tool -> JSONObject.parseObject(tool, NL2SQLTool.class)) .collect(Collectors.toList()); } @@ -120,10 +115,8 @@ public class Agent extends RecordInfo { if (CollectionUtils.isEmpty(commonAgentTools)) { return new HashSet<>(); } - return commonAgentTools.stream() - .map(NL2SQLTool::getDataSetIds) - .filter(modelIds -> !CollectionUtils.isEmpty(modelIds)) - .flatMap(Collection::stream) + return commonAgentTools.stream().map(NL2SQLTool::getDataSetIds) + .filter(modelIds -> !CollectionUtils.isEmpty(modelIds)).flatMap(Collection::stream) .collect(Collectors.toSet()); } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/AgentToolType.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/AgentToolType.java index ba04f4d21..45ea4355e 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/AgentToolType.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/agent/AgentToolType.java @@ -4,9 +4,7 @@ import java.util.HashMap; import java.util.Map; public enum AgentToolType { - NL2SQL_RULE("基于规则Text-to-SQL"), - NL2SQL_LLM("基于大模型Text-to-SQL"), - PLUGIN("第三方插件"); + NL2SQL_RULE("基于规则Text-to-SQL"), NL2SQL_LLM("基于大模型Text-to-SQL"), PLUGIN("第三方插件"); private String title; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java index 7427f0878..2b1f0030b 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/PlainTextExecutor.java @@ -26,14 +26,10 @@ import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULT public class PlainTextExecutor implements ChatQueryExecutor { - private static final String INSTRUCTION = - "" - + "#Role: You are a nice person to talk to.\n" - + "#Task: Respond quickly and nicely to the user." - + "#Rules: 1.ALWAYS use the same language as the input.\n" - + "#History Inputs: %s\n" - + "#Current Input: %s\n" - + "#Your response: "; + private static final String INSTRUCTION = "" + "#Role: You are a nice person to talk to.\n" + + "#Task: Respond quickly and nicely to the user." + + "#Rules: 1.ALWAYS use the same language as the input.\n" + "#History Inputs: %s\n" + + "#Current Input: %s\n" + "#Your response: "; @Override public QueryResult execute(ExecuteContext executeContext) { @@ -41,11 +37,8 @@ public class PlainTextExecutor implements ChatQueryExecutor { return null; } - String promptStr = - String.format( - INSTRUCTION, - getHistoryInputs(executeContext), - executeContext.getQueryText()); + String promptStr = String.format(INSTRUCTION, getHistoryInputs(executeContext), + executeContext.getQueryText()); Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); AgentService agentService = ContextUtils.getBean(AgentService.class); @@ -74,18 +67,15 @@ public class PlainTextExecutor implements ChatQueryExecutor { Boolean globalMultiTurnConfig = Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE)); Boolean multiTurnConfig = - agentMultiTurnConfig != null - ? agentMultiTurnConfig.isEnableMultiTurn() + agentMultiTurnConfig != null ? agentMultiTurnConfig.isEnableMultiTurn() : globalMultiTurnConfig; if (Boolean.TRUE.equals(multiTurnConfig)) { List queryResps = getHistoryQueries(executeContext.getChatId(), 5); - queryResps.stream() - .forEach( - p -> { - historyInput.append(p.getQueryText()); - historyInput.append(";"); - }); + queryResps.stream().forEach(p -> { + historyInput.append(p.getQueryText()); + historyInput.append(";"); + }); } return historyInput.toString(); @@ -93,18 +83,13 @@ public class PlainTextExecutor implements ChatQueryExecutor { private List getHistoryQueries(int chatId, int multiNum) { ChatManageService chatManageService = ContextUtils.getBean(ChatManageService.class); - List contextualParseInfoList = - chatManageService.getChatQueries(chatId).stream() - .filter( - q -> - Objects.nonNull(q.getQueryResult()) - && q.getQueryResult().getQueryState() - == QueryState.SUCCESS) - .collect(Collectors.toList()); + List contextualParseInfoList = chatManageService.getChatQueries(chatId).stream() + .filter(q -> Objects.nonNull(q.getQueryResult()) + && q.getQueryResult().getQueryState() == QueryState.SUCCESS) + .collect(Collectors.toList()); - List contextualList = - contextualParseInfoList.subList( - 0, Math.min(multiNum, contextualParseInfoList.size())); + List contextualList = contextualParseInfoList.subList(0, + Math.min(multiNum, contextualParseInfoList.size())); Collections.reverse(contextualList); return contextualList; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java index 2abdea5b7..41a99aac3 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/executor/SqlExecutor.java @@ -31,35 +31,26 @@ public class SqlExecutor implements ChatQueryExecutor { QueryResult queryResult = doExecute(executeContext); if (queryResult != null) { - String textResult = - ResultFormatter.transform2TextNew( - queryResult.getQueryColumns(), queryResult.getQueryResults()); + String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(), + queryResult.getQueryResults()); queryResult.setTextResult(textResult); if (queryResult.getQueryState().equals(QueryState.SUCCESS) && queryResult.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) { Text2SQLExemplar exemplar = JsonUtil.toObject( - JsonUtil.toString( - executeContext - .getParseInfo() - .getProperties() - .get(Text2SQLExemplar.PROPERTY_KEY)), + JsonUtil.toString(executeContext.getParseInfo().getProperties() + .get(Text2SQLExemplar.PROPERTY_KEY)), Text2SQLExemplar.class); MemoryService memoryService = ContextUtils.getBean(MemoryService.class); - memoryService.createMemory( - ChatMemoryDO.builder() - .agentId(executeContext.getAgent().getId()) - .status(MemoryStatus.PENDING) - .question(exemplar.getQuestion()) - .sideInfo(exemplar.getSideInfo()) - .dbSchema(exemplar.getDbSchema()) - .s2sql(exemplar.getSql()) - .createdBy(executeContext.getUser().getName()) - .updatedBy(executeContext.getUser().getName()) - .createdAt(new Date()) - .build()); + memoryService.createMemory(ChatMemoryDO.builder() + .agentId(executeContext.getAgent().getId()).status(MemoryStatus.PENDING) + .question(exemplar.getQuestion()).sideInfo(exemplar.getSideInfo()) + .dbSchema(exemplar.getDbSchema()).s2sql(exemplar.getSql()) + .createdBy(executeContext.getUser().getName()) + .updatedBy(executeContext.getUser().getName()).createdAt(new Date()) + .build()); } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java index 2a02f0be7..76d237e7a 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/memory/MemoryReviewTask.java @@ -27,25 +27,22 @@ public class MemoryReviewTask { private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); - private static final String INSTRUCTION = - "" - + "\n#Role: You are a senior data engineer experienced in writing SQL." - + "\n#Task: Your will be provided with a user question and the SQL written by junior engineer," - + "please take a review and give your opinion." - + "\n#Rules: " - + "1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`." - + "2.NO NEED to include date filter in the where clause if not explicitly expressed in the `Question`." - + "\n#Question: %s" - + "\n#Schema: %s" - + "\n#SideInfo: %s" - + "\n#SQL: %s" - + "\n#Response: "; + private static final String INSTRUCTION = "" + + "\n#Role: You are a senior data engineer experienced in writing SQL." + + "\n#Task: Your will be provided with a user question and the SQL written by junior engineer," + + "please take a review and give your opinion." + "\n#Rules: " + + "1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`." + + "2.NO NEED to include date filter in the where clause if not explicitly expressed in the `Question`." + + "\n#Question: %s" + "\n#Schema: %s" + "\n#SideInfo: %s" + "\n#SQL: %s" + + "\n#Response: "; private static final Pattern OUTPUT_PATTERN = Pattern.compile("opinion=(.*),.*comment=(.*)"); - @Autowired private MemoryService memoryService; + @Autowired + private MemoryService memoryService; - @Autowired private AgentService agentService; + @Autowired + private AgentService agentService; @Scheduled(fixedDelay = 60 * 1000) public void review() { @@ -78,8 +75,8 @@ public class MemoryReviewTask { } private String createPromptString(ChatMemoryDO m) { - return String.format( - INSTRUCTION, m.getQuestion(), m.getDbSchema(), m.getSideInfo(), m.getS2sql()); + return String.format(INSTRUCTION, m.getQuestion(), m.getDbSchema(), m.getSideInfo(), + m.getS2sql()); } private void processResponse(String response, ChatMemoryDO m) { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java index 1cbfa9311..1082b56e1 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2PluginParser.java @@ -21,13 +21,10 @@ public class NL2PluginParser implements ChatQueryParser { return; } - pluginRecognizers.forEach( - pluginRecognizer -> { - pluginRecognizer.recognize(parseContext, parseResp); - log.info( - "{} recallResult:{}", - pluginRecognizer.getClass().getSimpleName(), - JsonUtil.toString(parseResp)); - }); + pluginRecognizers.forEach(pluginRecognizer -> { + pluginRecognizer.recognize(parseContext, parseResp); + log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(), + JsonUtil.toString(parseResp)); + }); } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java index 9bf9b8419..4c2f58a25 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/NL2SQLParser.java @@ -52,33 +52,27 @@ public class NL2SQLParser implements ChatQueryParser { private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); - private static final String REWRITE_USER_QUESTION_INSTRUCTION = - "" - + "#Role: You are a data product manager experienced in data requirements." - + "#Task: Your will be provided with current and history questions asked by a user," - + "along with their mapped schema elements(metric, dimension and value)," - + "please try understanding the semantics and rewrite a question." - + "#Rules: " - + "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges." - + "2.ONLY respond with the rewritten question." - + "#Current Question: {{current_question}}" - + "#Current Mapped Schema: {{current_schema}}" - + "#History Question: {{history_question}}" - + "#History Mapped Schema: {{history_schema}}" - + "#History SQL: {{history_sql}}" - + "#Rewritten Question: "; + private static final String REWRITE_USER_QUESTION_INSTRUCTION = "" + + "#Role: You are a data product manager experienced in data requirements." + + "#Task: Your will be provided with current and history questions asked by a user," + + "along with their mapped schema elements(metric, dimension and value)," + + "please try understanding the semantics and rewrite a question." + "#Rules: " + + "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges." + + "2.ONLY respond with the rewritten question." + + "#Current Question: {{current_question}}" + + "#Current Mapped Schema: {{current_schema}}" + + "#History Question: {{history_question}}" + + "#History Mapped Schema: {{history_schema}}" + "#History SQL: {{history_sql}}" + + "#Rewritten Question: "; - private static final String REWRITE_ERROR_MESSAGE_INSTRUCTION = - "" - + "#Role: You are a data business partner who closely interacts with business people.\n" - + "#Task: Your will be provided with user input, system output and some examples, " - + "please respond shortly to teach user how to ask the right question, " - + "by using `Examples` as references." - + "#Rules: ALWAYS respond with the same language as the `Input`.\n" - + "#Input: {{user_question}}\n" - + "#Output: {{system_message}}\n" - + "#Examples: {{examples}}\n" - + "#Response: "; + private static final String REWRITE_ERROR_MESSAGE_INSTRUCTION = "" + + "#Role: You are a data business partner who closely interacts with business people.\n" + + "#Task: Your will be provided with user input, system output and some examples, " + + "please respond shortly to teach user how to ask the right question, " + + "by using `Examples` as references." + + "#Rules: ALWAYS respond with the same language as the `Input`.\n" + + "#Input: {{user_question}}\n" + "#Output: {{system_message}}\n" + + "#Examples: {{examples}}\n" + "#Response: "; @Override public void parse(ParseContext parseContext, ParseResp parseResp) { @@ -100,13 +94,10 @@ public class NL2SQLParser implements ChatQueryParser { parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses()); } else { if (parseContext.enbaleLLM()) { - parseResp.setErrorMsg( - rewriteErrorMessage( - parseContext.getQueryText(), - text2SqlParseResp.getErrorMsg(), - queryNLReq.getDynamicExemplars(), - parseContext.getAgent().getExamples(), - parseContext.getAgent().getModelConfig())); + parseResp.setErrorMsg(rewriteErrorMessage(parseContext.getQueryText(), + text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars(), + parseContext.getAgent().getExamples(), + parseContext.getAgent().getModelConfig())); } } parseResp.setState(text2SqlParseResp.getState()); @@ -141,40 +132,26 @@ public class NL2SQLParser implements ChatQueryParser { StringBuilder textBuilder = new StringBuilder(); textBuilder.append("**数据集:** ").append(parseInfo.getDataSet().getName()).append(" "); Optional metric = parseInfo.getMetrics().stream().findFirst(); - metric.ifPresent( - schemaElement -> - textBuilder.append("**指标:** ").append(schemaElement.getName()).append(" ")); - List dimensionNames = - parseInfo.getDimensions().stream() - .map(SchemaElement::getName) - .filter(Objects::nonNull) - .collect(Collectors.toList()); + metric.ifPresent(schemaElement -> textBuilder.append("**指标:** ") + .append(schemaElement.getName()).append(" ")); + List dimensionNames = parseInfo.getDimensions().stream().map(SchemaElement::getName) + .filter(Objects::nonNull).collect(Collectors.toList()); if (!CollectionUtils.isEmpty(dimensionNames)) { textBuilder.append("**维度:** ").append(String.join(",", dimensionNames)); } textBuilder.append("\n\n**筛选条件:** \n"); if (parseInfo.getDateInfo() != null) { - textBuilder - .append("**数据时间:** ") - .append(parseInfo.getDateInfo().getStartDate()) - .append("~") - .append(parseInfo.getDateInfo().getEndDate()) - .append(" "); + textBuilder.append("**数据时间:** ").append(parseInfo.getDateInfo().getStartDate()) + .append("~").append(parseInfo.getDateInfo().getEndDate()).append(" "); } if (!CollectionUtils.isEmpty(parseInfo.getDimensionFilters()) || CollectionUtils.isEmpty(parseInfo.getMetricFilters())) { Set queryFilters = parseInfo.getDimensionFilters(); queryFilters.addAll(parseInfo.getMetricFilters()); for (QueryFilter queryFilter : queryFilters) { - textBuilder - .append("**") - .append(queryFilter.getName()) - .append("**") - .append(" ") - .append(queryFilter.getOperator().getValue()) - .append(" ") - .append(queryFilter.getValue()) - .append(" "); + textBuilder.append("**").append(queryFilter.getName()).append("**").append(" ") + .append(queryFilter.getOperator().getValue()).append(" ") + .append(queryFilter.getValue()).append(" "); } } parseInfo.setTextInfo(textBuilder.toString()); @@ -187,8 +164,7 @@ public class NL2SQLParser implements ChatQueryParser { Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE)); Boolean multiTurnConfig = - agentMultiTurnConfig != null - ? agentMultiTurnConfig.isEnableMultiTurn() + agentMultiTurnConfig != null ? agentMultiTurnConfig.isEnableMultiTurn() : globalMultiTurnConfig; if (!Boolean.TRUE.equals(multiTurnConfig)) { return; @@ -232,30 +208,20 @@ public class NL2SQLParser implements ChatQueryParser { QueryNLReq rewrittenQueryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext); MapResp rewrittenQueryMapResult = chatLayerService.performMapping(rewrittenQueryNLReq); parseContext.setMapInfo(rewrittenQueryMapResult.getMapInfo()); - log.info( - "Last Query: {} Current Query: {}, Rewritten Query: {}", - lastQuery.getQueryText(), - currentMapResult.getQueryText(), - rewrittenQuery); + log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", lastQuery.getQueryText(), + currentMapResult.getQueryText(), rewrittenQuery); } - private String rewriteErrorMessage( - String userQuestion, - String errMsg, - List similarExemplars, - List agentExamples, + private String rewriteErrorMessage(String userQuestion, String errMsg, + List similarExemplars, List agentExamples, ChatModelConfig modelConfig) { Map variables = new HashMap<>(); variables.put("user_question", userQuestion); variables.put("system_message", errMsg); StringBuilder exampleStr = new StringBuilder(); - similarExemplars.forEach( - e -> - exampleStr.append( - String.format( - " ", - e.getQuestion(), e.getDbSchema()))); + similarExemplars.forEach(e -> exampleStr.append( + String.format(" ", e.getQuestion(), e.getDbSchema()))); agentExamples.forEach(e -> exampleStr.append(String.format(" ", e))); variables.put("examples", exampleStr); @@ -297,18 +263,13 @@ public class NL2SQLParser implements ChatQueryParser { private List getHistoryQueries(int chatId, int multiNum) { ChatManageService chatManageService = ContextUtils.getBean(ChatManageService.class); - List contextualParseInfoList = - chatManageService.getChatQueries(chatId).stream() - .filter( - q -> - Objects.nonNull(q.getQueryResult()) - && q.getQueryResult().getQueryState() - == QueryState.SUCCESS) - .collect(Collectors.toList()); + List contextualParseInfoList = chatManageService.getChatQueries(chatId).stream() + .filter(q -> Objects.nonNull(q.getQueryResult()) + && q.getQueryResult().getQueryState() == QueryState.SUCCESS) + .collect(Collectors.toList()); - List contextualList = - contextualParseInfoList.subList( - 0, Math.min(multiNum, contextualParseInfoList.size())); + List contextualList = contextualParseInfoList.subList(0, + Math.min(multiNum, contextualParseInfoList.size())); Collections.reverse(contextualList); return contextualList; } @@ -320,9 +281,8 @@ public class NL2SQLParser implements ChatQueryParser { ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); int exemplarRecallNumber = Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER)); - List exemplars = - exemplarManager.recallExemplars( - memoryCollectionName, queryNLReq.getQueryText(), exemplarRecallNumber); + List exemplars = exemplarManager.recallExemplars(memoryCollectionName, + queryNLReq.getQueryText(), exemplarRecallNumber); queryNLReq.getDynamicExemplars().addAll(exemplars); } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/ParserConfig.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/ParserConfig.java index b19dc074f..b3f5e703d 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/ParserConfig.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/parser/ParserConfig.java @@ -10,11 +10,6 @@ import org.springframework.stereotype.Service; public class ParserConfig extends ParameterConfig { public static final Parameter PARSER_MULTI_TURN_ENABLE = - new Parameter( - "s2.parser.multi-turn.enable", - "false", - "是否开启多轮对话", - "开启多轮对话将消耗更多token", - "bool", - "Parser相关配置"); + new Parameter("s2.parser.multi-turn.enable", "false", "是否开启多轮对话", "开启多轮对话将消耗更多token", + "bool", "Parser相关配置"); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/CostType.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/CostType.java index 35cd8c696..533fde779 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/CostType.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/dataobject/CostType.java @@ -1,10 +1,7 @@ package com.tencent.supersonic.chat.server.persistence.dataobject; public enum CostType { - MAPPER(1, "mapper"), - PARSER(2, "parser"), - QUERY(3, "query"), - PROCESSOR(4, "processor"); + MAPPER(1, "mapper"), PARSER(2, "parser"), QUERY(3, "query"), PROCESSOR(4, "processor"); private Integer type; private String name; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/AgentDOMapper.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/AgentDOMapper.java index 4b751135e..a833397d0 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/AgentDOMapper.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/AgentDOMapper.java @@ -5,4 +5,5 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO; import org.apache.ibatis.annotations.Mapper; @Mapper -public interface AgentDOMapper extends BaseMapper {} +public interface AgentDOMapper extends BaseMapper { +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/ChatMemoryMapper.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/ChatMemoryMapper.java index 6f4341c82..50393eb37 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/ChatMemoryMapper.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/ChatMemoryMapper.java @@ -5,4 +5,5 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO; import org.apache.ibatis.annotations.Mapper; @Mapper -public interface ChatMemoryMapper extends BaseMapper {} +public interface ChatMemoryMapper extends BaseMapper { +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/ChatQueryDOMapper.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/ChatQueryDOMapper.java index 95ccc4ae9..9b7ce6e6a 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/ChatQueryDOMapper.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/ChatQueryDOMapper.java @@ -5,4 +5,5 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO; import org.apache.ibatis.annotations.Mapper; @Mapper -public interface ChatQueryDOMapper extends BaseMapper {} +public interface ChatQueryDOMapper extends BaseMapper { +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/PluginDOMapper.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/PluginDOMapper.java index 21b675d77..356951e3b 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/PluginDOMapper.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/mapper/PluginDOMapper.java @@ -5,4 +5,5 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.PluginDO; import org.apache.ibatis.annotations.Mapper; @Mapper -public interface PluginDOMapper extends BaseMapper {} +public interface PluginDOMapper extends BaseMapper { +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatQueryRepository.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatQueryRepository.java index 0b18d5a80..5ed7b45e1 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatQueryRepository.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/ChatQueryRepository.java @@ -30,9 +30,7 @@ public interface ChatQueryRepository { Long createChatQuery(ChatParseReq chatParseReq); - List batchSaveParseInfo( - ChatParseReq chatParseReq, - ParseResp parseResult, + List batchSaveParseInfo(ChatParseReq chatParseReq, ParseResp parseResult, List candidateParses); ChatParseDO getParseInfo(Long questionId, int parseId); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatConfigRepositoryImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatConfigRepositoryImpl.java index b91db2332..9994bd72e 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatConfigRepositoryImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatConfigRepositoryImpl.java @@ -23,8 +23,8 @@ public class ChatConfigRepositoryImpl implements ChatConfigRepository { private final ChatConfigHelper chatConfigHelper; private final ChatConfigMapper chatConfigMapper; - public ChatConfigRepositoryImpl( - ChatConfigHelper chatConfigHelper, ChatConfigMapper chatConfigMapper) { + public ChatConfigRepositoryImpl(ChatConfigHelper chatConfigHelper, + ChatConfigMapper chatConfigMapper) { this.chatConfigHelper = chatConfigHelper; this.chatConfigMapper = chatConfigMapper; } @@ -52,11 +52,8 @@ public class ChatConfigRepositoryImpl implements ChatConfigRepository { List chaConfigDOList = chatConfigMapper.search(filterInternal); if (!CollectionUtils.isEmpty(chaConfigDOList)) { chaConfigDOList.stream() - .forEach( - chaConfigDO -> - chaConfigDescriptorList.add( - chatConfigHelper.chatConfigDO2Descriptor( - chaConfigDO.getModelId(), chaConfigDO))); + .forEach(chaConfigDO -> chaConfigDescriptorList.add(chatConfigHelper + .chatConfigDO2Descriptor(chaConfigDO.getModelId(), chaConfigDO))); } return chaConfigDescriptorList; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatQueryRepositoryImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatQueryRepositoryImpl.java index 639e42209..1c03a9af8 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatQueryRepositoryImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/persistence/repository/impl/ChatQueryRepositoryImpl.java @@ -40,11 +40,14 @@ import java.util.stream.Collectors; @Slf4j public class ChatQueryRepositoryImpl implements ChatQueryRepository { - @Autowired private ChatQueryDOMapper chatQueryDOMapper; + @Autowired + private ChatQueryDOMapper chatQueryDOMapper; - @Autowired private ChatParseMapper chatParseMapper; + @Autowired + private ChatParseMapper chatParseMapper; - @Autowired private ShowCaseCustomMapper showCaseCustomMapper; + @Autowired + private ShowCaseCustomMapper showCaseCustomMapper; @Override public PageInfo getChatQuery(PageQueryInfoReq pageQueryInfoReq, Long chatId) { @@ -67,11 +70,9 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository { .doSelectPageInfo(() -> chatQueryDOMapper.selectList(queryWrapper)); PageInfo chatQueryVOPageInfo = PageUtils.pageInfo2PageInfoVo(pageInfo); - chatQueryVOPageInfo.setList( - pageInfo.getList().stream() - .sorted(Comparator.comparingInt(o -> o.getQuestionId().intValue())) - .map(this::convertTo) - .collect(Collectors.toList())); + chatQueryVOPageInfo.setList(pageInfo.getList().stream() + .sorted(Comparator.comparingInt(o -> o.getQuestionId().intValue())) + .map(this::convertTo).collect(Collectors.toList())); return chatQueryVOPageInfo; } @@ -94,22 +95,16 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository { QueryWrapper queryWrapper = new QueryWrapper<>(); queryWrapper.lambda().eq(ChatQueryDO::getChatId, chatId); queryWrapper.lambda().orderByDesc(ChatQueryDO::getQuestionId); - return chatQueryDOMapper.selectList(queryWrapper).stream() - .map(q -> convertTo(q)) + return chatQueryDOMapper.selectList(queryWrapper).stream().map(q -> convertTo(q)) .collect(Collectors.toList()); } @Override public List queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId) { return showCaseCustomMapper - .queryShowCase( - pageQueryInfoReq.getLimitStart(), - pageQueryInfoReq.getPageSize(), - agentId, - pageQueryInfoReq.getUserName()) - .stream() - .map(this::convertTo) - .collect(Collectors.toList()); + .queryShowCase(pageQueryInfoReq.getLimitStart(), pageQueryInfoReq.getPageSize(), + agentId, pageQueryInfoReq.getUserName()) + .stream().map(this::convertTo).collect(Collectors.toList()); } private QueryResp convertTo(ChatQueryDO chatQueryDO) { @@ -121,9 +116,8 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository { queryResult.setQueryId(chatQueryDO.getQuestionId()); queryResp.setQueryResult(queryResult); } - queryResp.setSimilarQueries( - JSONObject.parseArray( - chatQueryDO.getSimilarQueries(), SimilarQueryRecallResp.class)); + queryResp.setSimilarQueries(JSONObject.parseArray(chatQueryDO.getSimilarQueries(), + SimilarQueryRecallResp.class)); queryResp.setParseTimeCost( JsonUtil.toObject(chatQueryDO.getParseTimeCost(), ParseTimeCostResp.class)); return queryResp; @@ -147,9 +141,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository { } @Override - public List batchSaveParseInfo( - ChatParseReq chatParseReq, - ParseResp parseResult, + public List batchSaveParseInfo(ChatParseReq chatParseReq, ParseResp parseResult, List candidateParses) { List chatParseDOList = new ArrayList<>(); getChatParseDO(chatParseReq, parseResult.getQueryId(), candidateParses, chatParseDOList); @@ -159,11 +151,8 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository { return chatParseDOList; } - public void getChatParseDO( - ChatParseReq chatParseReq, - Long queryId, - List parses, - List chatParseDOList) { + public void getChatParseDO(ChatParseReq chatParseReq, Long queryId, + List parses, List chatParseDOList) { for (int i = 0; i < parses.size(); i++) { ChatParseDO chatParseDO = new ChatParseDO(); chatParseDO.setChatId(chatParseReq.getChatId()); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/ParseMode.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/ParseMode.java index 6e5a7b393..42da5e9ef 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/ParseMode.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/ParseMode.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.chat.server.plugin; public enum ParseMode { - EMBEDDING_RECALL, - FUNCTION_CALL; + EMBEDDING_RECALL, FUNCTION_CALL; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java index 19c92c388..35d3598db 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/PluginManager.java @@ -46,9 +46,11 @@ import java.util.stream.Collectors; @Component public class PluginManager { - @Autowired private EmbeddingConfig embeddingConfig; + @Autowired + private EmbeddingConfig embeddingConfig; - @Autowired private EmbeddingService embeddingService; + @Autowired + private EmbeddingService embeddingService; public static List getPluginAgentCanSupport(ParseContext parseContext) { PluginService pluginService = ContextUtils.getBean(PluginService.class); @@ -57,21 +59,14 @@ public class PluginManager { if (Objects.isNull(agent)) { return plugins; } - List pluginIds = - getPluginTools(agent).stream() - .map(PluginTool::getPlugins) - .flatMap(Collection::stream) - .collect(Collectors.toList()); + List pluginIds = getPluginTools(agent).stream().map(PluginTool::getPlugins) + .flatMap(Collection::stream).collect(Collectors.toList()); if (CollectionUtils.isEmpty(pluginIds)) { return Lists.newArrayList(); } - plugins = - plugins.stream() - .filter(plugin -> pluginIds.contains(plugin.getId())) - .collect(Collectors.toList()); - log.info( - "plugins witch can be supported by cur agent :{} {}", - agent.getName(), + plugins = plugins.stream().filter(plugin -> pluginIds.contains(plugin.getId())) + .collect(Collectors.toList()); + log.info("plugins witch can be supported by cur agent :{} {}", agent.getName(), plugins.stream().map(ChatPlugin::getName).collect(Collectors.toList())); return plugins; } @@ -84,8 +79,7 @@ public class PluginManager { if (CollectionUtils.isEmpty(tools)) { return Lists.newArrayList(); } - return tools.stream() - .map(tool -> JSONObject.parseObject(tool, PluginTool.class)) + return tools.stream().map(tool -> JSONObject.parseObject(tool, PluginTool.class)) .collect(Collectors.toList()); } @@ -142,23 +136,18 @@ public class PluginManager { public RetrieveQueryResult recognize(String embeddingText) { - RetrieveQuery retrieveQuery = - RetrieveQuery.builder() - .queryTextsList(Collections.singletonList(embeddingText)) - .build(); + RetrieveQuery retrieveQuery = RetrieveQuery.builder() + .queryTextsList(Collections.singletonList(embeddingText)).build(); - List resultList = - embeddingService.retrieveQuery( - embeddingConfig.getPresetCollection(), - retrieveQuery, - embeddingConfig.getNResult()); + List resultList = embeddingService.retrieveQuery( + embeddingConfig.getPresetCollection(), retrieveQuery, embeddingConfig.getNResult()); if (CollectionUtils.isNotEmpty(resultList)) { for (RetrieveQueryResult embeddingResp : resultList) { List embeddingRetrievals = embeddingResp.getRetrieval(); for (Retrieval embeddingRetrieval : embeddingRetrievals) { - embeddingRetrieval.setId( - getPluginIdFromEmbeddingId(embeddingRetrieval.getId())); + embeddingRetrieval + .setId(getPluginIdFromEmbeddingId(embeddingRetrieval.getId())); } } return resultList.get(0); @@ -173,8 +162,8 @@ public class PluginManager { int num = 0; for (String pattern : exampleQuestions) { TextSegment query = TextSegment.from(pattern); - TextSegmentConvert.addQueryId( - query, generateUniqueEmbeddingId(num, plugin.getId())); + TextSegmentConvert.addQueryId(query, + generateUniqueEmbeddingId(num, plugin.getId())); queries.add(query); num++; } @@ -250,14 +239,10 @@ public class PluginManager { return Sets.newHashSet(); } return schemaElementMatches.stream() - .filter( - schemaElementMatch -> - SchemaElementType.VALUE.equals( - schemaElementMatch.getElement().getType()) - || SchemaElementType.ID.equals( - schemaElementMatch.getElement().getType())) - .map(SchemaElementMatch::getElement) - .map(SchemaElement::getId) + .filter(schemaElementMatch -> SchemaElementType.VALUE + .equals(schemaElementMatch.getElement().getType()) + || SchemaElementType.ID.equals(schemaElementMatch.getElement().getType())) + .map(SchemaElementMatch::getElement).map(SchemaElement::getId) .collect(Collectors.toSet()); } @@ -270,10 +255,8 @@ public class PluginManager { if (CollectionUtils.isEmpty(paramOptions)) { return Lists.newArrayList(); } - return paramOptions.stream() - .filter( - paramOption -> - ParamOption.ParamType.SEMANTIC.equals(paramOption.getParamType())) + return paramOptions.stream().filter( + paramOption -> ParamOption.ParamType.SEMANTIC.equals(paramOption.getParamType())) .collect(Collectors.toList()); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/ParamOption.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/ParamOption.java index 580254077..c0a613f54 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/ParamOption.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/ParamOption.java @@ -26,13 +26,10 @@ public class ParamOption { * forward */ public enum ParamType { - CUSTOM, - SEMANTIC, - FORWARD + CUSTOM, SEMANTIC, FORWARD } public enum OptionType { - REQUIRED, - OPTIONAL + REQUIRED, OPTIONAL } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/PluginSemanticQuery.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/PluginSemanticQuery.java index df5b4c9e8..5f076545c 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/PluginSemanticQuery.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/PluginSemanticQuery.java @@ -43,40 +43,31 @@ public abstract class PluginSemanticQuery { protected Map getElementMap(PluginParseResult pluginParseResult) { Map elementValueMap = new HashMap<>(); Map filterValueMap = getFilterMap(pluginParseResult); - List schemaElementMatchList = - parseInfo.getElementMatches().stream() - .filter(schemaElementMatch -> schemaElementMatch.getFrequency() != null) - .sorted( - Comparator.comparingLong(SchemaElementMatch::getFrequency) - .reversed()) - .collect(Collectors.toList()); + List schemaElementMatchList = parseInfo.getElementMatches().stream() + .filter(schemaElementMatch -> schemaElementMatch.getFrequency() != null) + .sorted(Comparator.comparingLong(SchemaElementMatch::getFrequency).reversed()) + .collect(Collectors.toList()); if (!CollectionUtils.isEmpty(schemaElementMatchList)) { - schemaElementMatchList.stream() - .filter( - schemaElementMatch -> - SchemaElementType.VALUE.equals( - schemaElementMatch.getElement().getType()) - || SchemaElementType.ID.equals( - schemaElementMatch.getElement().getType())) + schemaElementMatchList.stream().filter(schemaElementMatch -> SchemaElementType.VALUE + .equals(schemaElementMatch.getElement().getType()) + || SchemaElementType.ID.equals(schemaElementMatch.getElement().getType())) .filter(schemaElementMatch -> schemaElementMatch.getSimilarity() == 1.0) - .forEach( - schemaElementMatch -> { - Object queryFilterValue = - filterValueMap.get(schemaElementMatch.getElement().getId()); - if (queryFilterValue != null) { - if (String.valueOf(queryFilterValue) - .equals(String.valueOf(schemaElementMatch.getWord()))) { - elementValueMap.put( - String.valueOf( - schemaElementMatch.getElement().getId()), - schemaElementMatch.getWord()); - } - } else { - elementValueMap.computeIfAbsent( - String.valueOf(schemaElementMatch.getElement().getId()), - k -> schemaElementMatch.getWord()); - } - }); + .forEach(schemaElementMatch -> { + Object queryFilterValue = + filterValueMap.get(schemaElementMatch.getElement().getId()); + if (queryFilterValue != null) { + if (String.valueOf(queryFilterValue) + .equals(String.valueOf(schemaElementMatch.getWord()))) { + elementValueMap.put( + String.valueOf(schemaElementMatch.getElement().getId()), + schemaElementMatch.getWord()); + } + } else { + elementValueMap.computeIfAbsent( + String.valueOf(schemaElementMatch.getElement().getId()), + k -> schemaElementMatch.getWord()); + } + }); } return elementValueMap; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webpage/WebPageQuery.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webpage/WebPageQuery.java index 7af00cad6..9902ee566 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webpage/WebPageQuery.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webpage/WebPageQuery.java @@ -41,10 +41,8 @@ public class WebPageQuery extends PluginSemanticQuery { QueryResult queryResult = new QueryResult(); queryResult.setQueryMode(QUERY_MODE); Map properties = parseInfo.getProperties(); - PluginParseResult pluginParseResult = - JsonUtil.toObject( - JsonUtil.toString(properties.get(Constants.CONTEXT)), - PluginParseResult.class); + PluginParseResult pluginParseResult = JsonUtil.toObject( + JsonUtil.toString(properties.get(Constants.CONTEXT)), PluginParseResult.class); WebPageResp webPageResponse = buildResponse(pluginParseResult); queryResult.setResponse(webPageResponse); queryResult.setQueryState(QueryState.SUCCESS); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webservice/WebServiceQuery.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webservice/WebServiceQuery.java index 87c57f5bc..efefce08d 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webservice/WebServiceQuery.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/build/webservice/WebServiceQuery.java @@ -45,10 +45,8 @@ public class WebServiceQuery extends PluginSemanticQuery { QueryResult queryResult = new QueryResult(); queryResult.setQueryMode(QUERY_MODE); Map properties = parseInfo.getProperties(); - PluginParseResult pluginParseResult = - JsonUtil.toObject( - JsonUtil.toString(properties.get(Constants.CONTEXT)), - PluginParseResult.class); + PluginParseResult pluginParseResult = JsonUtil.toObject( + JsonUtil.toString(properties.get(Constants.CONTEXT)), PluginParseResult.class); WebServiceResp webServiceResponse = buildResponse(pluginParseResult); Object object = webServiceResponse.getResult(); // in order to show webServiceQuery result int frontend conveniently, @@ -74,9 +72,8 @@ public class WebServiceQuery extends PluginSemanticQuery { protected WebServiceResp buildResponse(PluginParseResult pluginParseResult) { WebServiceResp webServiceResponse = new WebServiceResp(); ChatPlugin plugin = pluginParseResult.getPlugin(); - WebBase webBase = - fillWebBaseResult( - JsonUtil.toObject(plugin.getConfig(), WebBase.class), pluginParseResult); + WebBase webBase = fillWebBaseResult(JsonUtil.toObject(plugin.getConfig(), WebBase.class), + pluginParseResult); webServiceResponse.setWebBase(webBase); List paramOptions = webBase.getParamOptions(); Map params = new HashMap<>(); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java index 74dde3215..c4d2935e1 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/PluginRecognizer.java @@ -41,17 +41,16 @@ public abstract class PluginRecognizer { public abstract PluginRecallResult recallPlugin(ParseContext parseContext); - public void buildQuery( - ParseContext parseContext, ParseResp parseResp, PluginRecallResult pluginRecallResult) { + public void buildQuery(ParseContext parseContext, ParseResp parseResp, + PluginRecallResult pluginRecallResult) { ChatPlugin plugin = pluginRecallResult.getPlugin(); Set dataSetIds = pluginRecallResult.getDataSetIds(); if (plugin.isContainsAllDataSet()) { dataSetIds = Sets.newHashSet(-1L); } for (Long dataSetId : dataSetIds) { - SemanticParseInfo semanticParseInfo = - buildSemanticParseInfo( - dataSetId, plugin, parseContext, pluginRecallResult.getDistance()); + SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin, + parseContext, pluginRecallResult.getDistance()); semanticParseInfo.setQueryMode(plugin.getType()); semanticParseInfo.setScore(pluginRecallResult.getScore()); parseResp.getSelectedParses().add(semanticParseInfo); @@ -62,8 +61,8 @@ public abstract class PluginRecognizer { return PluginManager.getPluginAgentCanSupport(parseContext); } - protected SemanticParseInfo buildSemanticParseInfo( - Long dataSetId, ChatPlugin plugin, ParseContext parseContext, double distance) { + protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin, + ParseContext parseContext, double distance) { List schemaElementMatches = parseContext.getMapInfo().getMatchedElements(dataSetId); QueryFilters queryFilters = parseContext.getQueryFilters(); @@ -97,21 +96,17 @@ public abstract class PluginRecognizer { return; } schemaElementMatches.stream() - .filter( - schemaElementMatch -> - SchemaElementType.VALUE.equals( - schemaElementMatch.getElement().getType()) - || SchemaElementType.ID.equals( - schemaElementMatch.getElement().getType())) - .forEach( - schemaElementMatch -> { - QueryFilter queryFilter = new QueryFilter(); - queryFilter.setValue(schemaElementMatch.getWord()); - queryFilter.setElementID(schemaElementMatch.getElement().getId()); - queryFilter.setName(schemaElementMatch.getElement().getName()); - queryFilter.setOperator(FilterOperatorEnum.EQUALS); - queryFilter.setBizName(schemaElementMatch.getElement().getBizName()); - semanticParseInfo.getDimensionFilters().add(queryFilter); - }); + .filter(schemaElementMatch -> SchemaElementType.VALUE + .equals(schemaElementMatch.getElement().getType()) + || SchemaElementType.ID.equals(schemaElementMatch.getElement().getType())) + .forEach(schemaElementMatch -> { + QueryFilter queryFilter = new QueryFilter(); + queryFilter.setValue(schemaElementMatch.getWord()); + queryFilter.setElementID(schemaElementMatch.getElement().getId()); + queryFilter.setName(schemaElementMatch.getElement().getName()); + queryFilter.setOperator(FilterOperatorEnum.EQUALS); + queryFilter.setBizName(schemaElementMatch.getElement().getBizName()); + semanticParseInfo.getDimensionFilters().add(queryFilter); + }); } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java index e0b1e3f1e..589192c45 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/plugin/recognize/embedding/EmbeddingRecallRecognizer.java @@ -53,12 +53,8 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer { plugin.setParseMode(ParseMode.EMBEDDING_RECALL); double similarity = embeddingRetrieval.getSimilarity(); double score = parseContext.getQueryText().length() * similarity; - return PluginRecallResult.builder() - .plugin(plugin) - .dataSetIds(dataSetList) - .score(score) - .distance(similarity) - .build(); + return PluginRecallResult.builder().plugin(plugin).dataSetIds(dataSetList) + .score(score).distance(similarity).build(); } } return null; @@ -71,12 +67,9 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer { List embeddingRetrievals = embeddingResp.getRetrieval(); if (!CollectionUtils.isEmpty(embeddingRetrievals)) { - embeddingRetrievals = - embeddingRetrievals.stream() - .sorted( - Comparator.comparingDouble( - o -> Math.abs(o.getSimilarity()))) - .collect(Collectors.toList()); + embeddingRetrievals = embeddingRetrievals.stream() + .sorted(Comparator.comparingDouble(o -> Math.abs(o.getSimilarity()))) + .collect(Collectors.toList()); embeddingResp.setRetrieval(embeddingRetrievals); } return embeddingRetrievals; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/ResultProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/ResultProcessor.java index 5075e863c..bda8d71ca 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/ResultProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/ResultProcessor.java @@ -1,4 +1,5 @@ package com.tencent.supersonic.chat.server.processor; /** A ResultProcessor wraps things up before returning results to users. */ -public interface ResultProcessor {} +public interface ResultProcessor { +} diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/DimensionRecommendProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/DimensionRecommendProcessor.java index 3031fe366..d96a75996 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/DimensionRecommendProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/DimensionRecommendProcessor.java @@ -52,28 +52,20 @@ public class DimensionRecommendProcessor implements ExecuteResultProcessor { List drillDownDimensions = Lists.newArrayList(); Set metricElements = dataSetSchema.getMetrics(); if (!CollectionUtils.isEmpty(metricElements)) { - Optional metric = - metricElements.stream() - .filter( - schemaElement -> - metricId.equals(schemaElement.getId()) - && !CollectionUtils.isEmpty( - schemaElement - .getRelatedSchemaElements())) - .findFirst(); + Optional metric = metricElements.stream() + .filter(schemaElement -> metricId.equals(schemaElement.getId()) + && !CollectionUtils.isEmpty(schemaElement.getRelatedSchemaElements())) + .findFirst(); if (metric.isPresent()) { - drillDownDimensions = - metric.get().getRelatedSchemaElements().stream() - .map(RelatedSchemaElement::getDimensionId) - .collect(Collectors.toList()); + drillDownDimensions = metric.get().getRelatedSchemaElements().stream() + .map(RelatedSchemaElement::getDimensionId).collect(Collectors.toList()); } } final List drillDownDimensionsFinal = drillDownDimensions; return dataSetSchema.getDimensions().stream() .filter(dim -> filterDimension(drillDownDimensionsFinal, dim)) .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) - .limit(recommend_dimension_size) - .collect(Collectors.toList()); + .limit(recommend_dimension_size).collect(Collectors.toList()); } private boolean filterDimension(List drillDownDimensions, SchemaElement dimension) { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/MetricRatioProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/MetricRatioProcessor.java index 559f5f602..501193169 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/MetricRatioProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/MetricRatioProcessor.java @@ -69,19 +69,14 @@ public class MetricRatioProcessor implements ExecuteResultProcessor { queryResult.setAggregateInfo(aggregateInfo); } - public AggregateInfo getAggregateInfo( - User user, SemanticParseInfo semanticParseInfo, QueryResult queryResult) { + public AggregateInfo getAggregateInfo(User user, SemanticParseInfo semanticParseInfo, + QueryResult queryResult) { Set resultMetricNames = new HashSet<>(); - queryResult.getQueryColumns().stream() - .forEach( - c -> - resultMetricNames.addAll( - SqlSelectHelper.getColumnFromExpr(c.getNameEn()))); - Optional ratioMetric = - semanticParseInfo.getMetrics().stream() - .filter(m -> resultMetricNames.contains(m.getBizName())) - .findFirst(); + queryResult.getQueryColumns().stream().forEach( + c -> resultMetricNames.addAll(SqlSelectHelper.getColumnFromExpr(c.getNameEn()))); + Optional ratioMetric = semanticParseInfo.getMetrics().stream() + .filter(m -> resultMetricNames.contains(m.getBizName())).findFirst(); AggregateInfo aggregateInfo = new AggregateInfo(); if (!ratioMetric.isPresent()) { @@ -90,20 +85,15 @@ public class MetricRatioProcessor implements ExecuteResultProcessor { try { String dateField = QueryReqBuilder.getDateField(semanticParseInfo.getDateInfo()); - Optional lastDayOp = - queryResult.getQueryResults().stream() - .filter(r -> r.containsKey(dateField)) - .map(r -> r.get(dateField).toString()) - .sorted(Comparator.reverseOrder()) - .findFirst(); + Optional lastDayOp = queryResult.getQueryResults().stream() + .filter(r -> r.containsKey(dateField)).map(r -> r.get(dateField).toString()) + .sorted(Comparator.reverseOrder()).findFirst(); if (!lastDayOp.isPresent()) { return new AggregateInfo(); } - Optional> lastValue = - queryResult.getQueryResults().stream() - .filter(r -> r.get(dateField).toString().equals(lastDayOp.get())) - .findFirst(); + Optional> lastValue = queryResult.getQueryResults().stream() + .filter(r -> r.get(dateField).toString().equals(lastDayOp.get())).findFirst(); MetricInfo metricInfo = new MetricInfo(); metricInfo.setStatistics(new HashMap<>()); @@ -115,23 +105,11 @@ public class MetricRatioProcessor implements ExecuteResultProcessor { metricInfo.setDate(lastValue.get().get(dateField).toString()); CompletableFuture metricInfoRoll = - CompletableFuture.supplyAsync( - () -> - queryRatio( - user, - semanticParseInfo, - ratioMetric.get(), - AggOperatorEnum.RATIO_ROLL, - queryResult)); + CompletableFuture.supplyAsync(() -> queryRatio(user, semanticParseInfo, + ratioMetric.get(), AggOperatorEnum.RATIO_ROLL, queryResult)); CompletableFuture metricInfoOver = - CompletableFuture.supplyAsync( - () -> - queryRatio( - user, - semanticParseInfo, - ratioMetric.get(), - AggOperatorEnum.RATIO_OVER, - queryResult)); + CompletableFuture.supplyAsync(() -> queryRatio(user, semanticParseInfo, + ratioMetric.get(), AggOperatorEnum.RATIO_OVER, queryResult)); CompletableFuture.allOf(metricInfoRoll, metricInfoOver); metricInfo.setName(metricInfoRoll.get().getName()); metricInfo.setValue(metricInfoRoll.get().getValue()); @@ -145,19 +123,15 @@ public class MetricRatioProcessor implements ExecuteResultProcessor { } @SneakyThrows - private MetricInfo queryRatio( - User user, - SemanticParseInfo semanticParseInfo, - SchemaElement metric, - AggOperatorEnum aggOperatorEnum, - QueryResult queryResult) { + private MetricInfo queryRatio(User user, SemanticParseInfo semanticParseInfo, + SchemaElement metric, AggOperatorEnum aggOperatorEnum, QueryResult queryResult) { QueryStructReq queryStructReq = QueryReqBuilder.buildStructRatioReq(semanticParseInfo, metric, aggOperatorEnum); String dateField = QueryReqBuilder.getDateField(semanticParseInfo.getDateInfo()); queryStructReq.setGroups(new ArrayList<>(Arrays.asList(dateField))); - queryStructReq.setDateInfo( - getRatioDateConf(aggOperatorEnum, semanticParseInfo, queryResult)); + queryStructReq + .setDateInfo(getRatioDateConf(aggOperatorEnum, semanticParseInfo, queryResult)); queryStructReq.setConvertToSql(false); SemanticLayerService queryService = ContextUtils.getBean(SemanticLayerService.class); SemanticQueryResp queryResp = queryService.queryByReq(queryStructReq, user); @@ -168,26 +142,22 @@ public class MetricRatioProcessor implements ExecuteResultProcessor { } Map result = queryResp.getResultList().get(0); - Optional valueColumn = - queryResp.getColumns().stream() - .filter(c -> c.getNameEn().equals(metric.getBizName())) - .findFirst(); + Optional valueColumn = queryResp.getColumns().stream() + .filter(c -> c.getNameEn().equals(metric.getBizName())).findFirst(); if (!valueColumn.isPresent()) { return metricInfo; } - String valueField = - String.format( - "%s_%s", valueColumn.get().getNameEn(), aggOperatorEnum.getOperator()); + String valueField = String.format("%s_%s", valueColumn.get().getNameEn(), + aggOperatorEnum.getOperator()); if (result.containsKey(valueColumn.get().getNameEn())) { DecimalFormat df = new DecimalFormat("#.####"); metricInfo.setValue(df.format(result.get(valueColumn.get().getNameEn()))); } String ratio = ""; if (Objects.nonNull(result.get(valueField))) { - ratio = - String.format("%.2f", (Double.valueOf(result.get(valueField).toString()) * 100)) - + "%"; + ratio = String.format("%.2f", (Double.valueOf(result.get(valueField).toString()) * 100)) + + "%"; } String statisticsRollName = RatioOverType.DAY_ON_DAY.getShowName(); String statisticsOverName = RatioOverType.WEEK_ON_DAY.getShowName(); @@ -199,28 +169,20 @@ public class MetricRatioProcessor implements ExecuteResultProcessor { statisticsRollName = RatioOverType.WEEK_ON_WEEK.getShowName(); statisticsOverName = RatioOverType.MONTH_ON_WEEK.getShowName(); } - metricInfo - .getStatistics() - .put( - aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL) - ? statisticsRollName - : statisticsOverName, - ratio); + metricInfo.getStatistics() + .put(aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL) ? statisticsRollName + : statisticsOverName, ratio); metricInfo.setName(metric.getName()); return metricInfo; } - private DateConf getRatioDateConf( - AggOperatorEnum aggOperatorEnum, - SemanticParseInfo semanticParseInfo, - QueryResult queryResult) { + private DateConf getRatioDateConf(AggOperatorEnum aggOperatorEnum, + SemanticParseInfo semanticParseInfo, QueryResult queryResult) { String dateField = QueryReqBuilder.getDateField(semanticParseInfo.getDateInfo()); Optional lastDayOp = - queryResult.getQueryResults().stream() - .map(r -> r.get(dateField).toString()) - .sorted(Comparator.reverseOrder()) - .findFirst(); + queryResult.getQueryResults().stream().map(r -> r.get(dateField).toString()) + .sorted(Comparator.reverseOrder()).findFirst(); if (!lastDayOp.isPresent()) { return semanticParseInfo.getDateInfo(); @@ -236,31 +198,25 @@ public class MetricRatioProcessor implements ExecuteResultProcessor { DateTimeFormatter formatter = DateUtils.getDateFormatter(lastDay, new String[] {DAY_FORMAT, DAY_FORMAT_INT}); LocalDate end = LocalDate.parse(lastDay, formatter); - start = - aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL) - ? end.minusDays(1).format(formatter) - : end.minusWeeks(1).format(formatter); + start = aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL) + ? end.minusDays(1).format(formatter) + : end.minusWeeks(1).format(formatter); } if (DatePeriodEnum.WEEK.equals(semanticParseInfo.getDateInfo().getPeriod())) { - DateTimeFormatter formatter = - DateUtils.getTimeFormatter( - lastDay, - new String[] {TIMES_FORMAT, DAY_FORMAT, TIME_FORMAT, DAY_FORMAT_INT}); + DateTimeFormatter formatter = DateUtils.getTimeFormatter(lastDay, + new String[] {TIMES_FORMAT, DAY_FORMAT, TIME_FORMAT, DAY_FORMAT_INT}); LocalDateTime end = LocalDateTime.parse(lastDay, formatter); - start = - aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL) - ? end.minusWeeks(1).format(formatter) - : end.minusMonths(1).with(DayOfWeek.MONDAY).format(formatter); + start = aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL) + ? end.minusWeeks(1).format(formatter) + : end.minusMonths(1).with(DayOfWeek.MONDAY).format(formatter); } if (DatePeriodEnum.MONTH.equals(semanticParseInfo.getDateInfo().getPeriod())) { - DateTimeFormatter formatter = - DateUtils.getDateFormatter( - lastDay, new String[] {MONTH_FORMAT, MONTH_FORMAT_INT}); + DateTimeFormatter formatter = DateUtils.getDateFormatter(lastDay, + new String[] {MONTH_FORMAT, MONTH_FORMAT_INT}); YearMonth end = YearMonth.parse(lastDay, formatter); - start = - aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL) - ? end.minusMonths(1).format(formatter) - : end.minusYears(1).format(formatter); + start = aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL) + ? end.minusMonths(1).format(formatter) + : end.minusYears(1).format(formatter); } dayList.add(start); dateConf.setDateList(dayList); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/MetricRecommendProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/MetricRecommendProcessor.java index bc1bb884d..a108c86b6 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/MetricRecommendProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/execute/MetricRecommendProcessor.java @@ -45,33 +45,24 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor { List metricNames = Collections.singletonList(parseInfo.getMetrics().iterator().next().getName()); Map filterCondition = new HashMap<>(); - filterCondition.put( - "modelId", parseInfo.getMetrics().iterator().next().getDataSetId().toString()); + filterCondition.put("modelId", + parseInfo.getMetrics().iterator().next().getDataSetId().toString()); filterCondition.put("type", SchemaElementType.METRIC.name()); - RetrieveQuery retrieveQuery = - RetrieveQuery.builder() - .queryTextsList(metricNames) - .filterCondition(filterCondition) - .queryEmbeddings(null) - .build(); + RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(metricNames) + .filterCondition(filterCondition).queryEmbeddings(null).build(); MetaEmbeddingService metaEmbeddingService = ContextUtils.getBean(MetaEmbeddingService.class); - List retrieveQueryResults = - metaEmbeddingService.retrieveQuery( - retrieveQuery, METRIC_RECOMMEND_SIZE + 1, new HashMap<>(), new HashSet<>()); + List retrieveQueryResults = metaEmbeddingService.retrieveQuery( + retrieveQuery, METRIC_RECOMMEND_SIZE + 1, new HashMap<>(), new HashSet<>()); if (CollectionUtils.isEmpty(retrieveQueryResults)) { return; } - List retrievals = - retrieveQueryResults.stream() - .flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream()) - .sorted(Comparator.comparingDouble(Retrieval::getSimilarity)) - .distinct() - .collect(Collectors.toList()); - Set metricIds = - parseInfo.getMetrics().stream() - .map(SchemaElement::getId) - .collect(Collectors.toSet()); + List retrievals = retrieveQueryResults.stream() + .flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream()) + .sorted(Comparator.comparingDouble(Retrieval::getSimilarity)).distinct() + .collect(Collectors.toList()); + Set metricIds = parseInfo.getMetrics().stream().map(SchemaElement::getId) + .collect(Collectors.toSet()); int metricOrder = 0; for (SchemaElement metric : parseInfo.getMetrics()) { metric.setOrder(metricOrder++); @@ -79,23 +70,15 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor { for (Retrieval retrieval : retrievals) { if (!metricIds.contains(Retrieval.getLongId(retrieval.getId()))) { if (Objects.nonNull(retrieval.getMetadata().get("id"))) { - String idStr = - retrieval - .getMetadata() - .get("id") - .toString() - .replaceAll(DictWordType.NATURE_SPILT, ""); + String idStr = retrieval.getMetadata().get("id").toString() + .replaceAll(DictWordType.NATURE_SPILT, ""); retrieval.getMetadata().put("id", idStr); } String metaStr = JSONObject.toJSONString(retrieval.getMetadata()); SchemaElement schemaElement = JSONObject.parseObject(metaStr, SchemaElement.class); if (retrieval.getMetadata().containsKey("dataSetId")) { - String dataSetId = - retrieval - .getMetadata() - .get("dataSetId") - .toString() - .replace(Constants.UNDERLINE, ""); + String dataSetId = retrieval.getMetadata().get("dataSetId").toString() + .replace(Constants.UNDERLINE, ""); schemaElement.setDataSetId(Long.parseLong(dataSetId)); } schemaElement.setOrder(++metricOrder); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java index 50eb73f7e..96dfbb215 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/QueryRecommendProcessor.java @@ -43,13 +43,8 @@ public class QueryRecommendProcessor implements ParseResultProcessor { String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId); List exemplars = exemplarService.recallExemplars(memoryCollectionName, queryText, 5); - return exemplars.stream() - .map( - sqlExemplar -> - SimilarQueryRecallResp.builder() - .queryText(sqlExemplar.getQuestion()) - .build()) - .collect(Collectors.toList()); + return exemplars.stream().map(sqlExemplar -> SimilarQueryRecallResp.builder() + .queryText(sqlExemplar.getQuestion()).build()).collect(Collectors.toList()); } private ChatQueryDO getChatQuery(Long queryId) { diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/TimeCostProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/TimeCostProcessor.java index d1854099f..6c2e346f4 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/TimeCostProcessor.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/TimeCostProcessor.java @@ -11,11 +11,7 @@ public class TimeCostProcessor implements ParseResultProcessor { @Override public void process(ParseContext parseContext, ParseResp parseResp) { long parseStartTime = parseResp.getParseTimeCost().getParseStartTime(); - parseResp - .getParseTimeCost() - .setParseTime( - System.currentTimeMillis() - - parseStartTime - - parseResp.getParseTimeCost().getSqlTime()); + parseResp.getParseTimeCost().setParseTime(System.currentTimeMillis() - parseStartTime + - parseResp.getParseTimeCost().getSqlTime()); } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java index 62d8020ad..c5490d0a8 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/AgentController.java @@ -26,21 +26,18 @@ import java.util.Map; @RequestMapping({"/api/chat/agent", "/openapi/chat/agent"}) public class AgentController { - @Autowired private AgentService agentService; + @Autowired + private AgentService agentService; @PostMapping - public Agent createAgent( - @RequestBody Agent agent, - HttpServletRequest httpServletRequest, + public Agent createAgent(@RequestBody Agent agent, HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) { User user = UserHolder.findUser(httpServletRequest, httpServletResponse); return agentService.createAgent(agent, user); } @PutMapping - public Agent updateAgent( - @RequestBody Agent agent, - HttpServletRequest httpServletRequest, + public Agent updateAgent(@RequestBody Agent agent, HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) { User user = UserHolder.findUser(httpServletRequest, httpServletResponse); return agentService.updateAgent(agent, user); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatConfigController.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatConfigController.java index 4fe5714f6..598482823 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatConfigController.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatConfigController.java @@ -29,33 +29,29 @@ import java.util.List; @RequestMapping({"/api/chat/conf", "/openapi/chat/conf"}) public class ChatConfigController { - @Autowired private ConfigService configService; + @Autowired + private ConfigService configService; - @Autowired private SemanticLayerService semanticLayerService; + @Autowired + private SemanticLayerService semanticLayerService; @PostMapping - public Long addChatConfig( - @RequestBody ChatConfigBaseReq extendBaseCmd, - HttpServletRequest request, - HttpServletResponse response) { + public Long addChatConfig(@RequestBody ChatConfigBaseReq extendBaseCmd, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return configService.addConfig(extendBaseCmd, user); } @PutMapping - public Long editModelExtend( - @RequestBody ChatConfigEditReqReq extendEditCmd, - HttpServletRequest request, - HttpServletResponse response) { + public Long editModelExtend(@RequestBody ChatConfigEditReqReq extendEditCmd, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return configService.editConfig(extendEditCmd, user); } @PostMapping("/search") - public List search( - @RequestBody ChatConfigFilter filter, - HttpServletRequest request, - HttpServletResponse response) { + public List search(@RequestBody ChatConfigFilter filter, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return configService.search(filter, user); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatController.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatController.java index d7794cd3f..c79ac2795 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatController.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatController.java @@ -25,14 +25,13 @@ import java.util.List; @RequestMapping({"/api/chat/manage", "/openapi/chat/manage"}) public class ChatController { - @Autowired private ChatManageService chatService; + @Autowired + private ChatManageService chatService; @PostMapping("/save") - public Boolean save( - @RequestParam(value = "chatName") String chatName, + public Boolean save(@RequestParam(value = "chatName") String chatName, @RequestParam(value = "agentId", required = false) Integer agentId, - HttpServletRequest request, - HttpServletResponse response) { + HttpServletRequest request, HttpServletResponse response) { chatService.addChat(UserHolder.findUser(request, response), chatName, agentId); return true; } @@ -40,50 +39,42 @@ public class ChatController { @GetMapping("/getAll") public List getAllConversions( @RequestParam(value = "agentId", required = false) Integer agentId, - HttpServletRequest request, - HttpServletResponse response) { + HttpServletRequest request, HttpServletResponse response) { String userName = UserHolder.findUser(request, response).getName(); return chatService.getAll(userName, agentId); } @PostMapping("/delete") - public Boolean deleteConversion( - @RequestParam(value = "chatId") long chatId, - HttpServletRequest request, - HttpServletResponse response) { + public Boolean deleteConversion(@RequestParam(value = "chatId") long chatId, + HttpServletRequest request, HttpServletResponse response) { String userName = UserHolder.findUser(request, response).getName(); return chatService.deleteChat(chatId, userName); } @PostMapping("/updateChatName") - public Boolean updateConversionName( - @RequestParam(value = "chatId") Long chatId, - @RequestParam(value = "chatName") String chatName, - HttpServletRequest request, + public Boolean updateConversionName(@RequestParam(value = "chatId") Long chatId, + @RequestParam(value = "chatName") String chatName, HttpServletRequest request, HttpServletResponse response) { String userName = UserHolder.findUser(request, response).getName(); return chatService.updateChatName(chatId, chatName, userName); } @PostMapping("/updateQAFeedback") - public Boolean updateQAFeedback( - @RequestParam(value = "id") Integer id, + public Boolean updateQAFeedback(@RequestParam(value = "id") Integer id, @RequestParam(value = "score") Integer score, @RequestParam(value = "feedback", required = false) String feedback) { return chatService.updateFeedback(id, score, feedback); } @PostMapping("/updateChatIsTop") - public Boolean updateConversionIsTop( - @RequestParam(value = "chatId") Long chatId, @RequestParam(value = "isTop") int isTop) { + public Boolean updateConversionIsTop(@RequestParam(value = "chatId") Long chatId, + @RequestParam(value = "isTop") int isTop) { return chatService.updateChatIsTop(chatId, isTop); } @PostMapping("/pageQueryInfo") - public PageInfo pageQueryInfo( - @RequestBody PageQueryInfoReq pageQueryInfoCommand, - @RequestParam(value = "chatId") long chatId, - HttpServletRequest request, + public PageInfo pageQueryInfo(@RequestBody PageQueryInfoReq pageQueryInfoCommand, + @RequestParam(value = "chatId") long chatId, HttpServletRequest request, HttpServletResponse response) { pageQueryInfoCommand.setUserName(UserHolder.findUser(request, response).getName()); return chatService.queryInfo(pageQueryInfoCommand, chatId); @@ -95,8 +86,7 @@ public class ChatController { } @PostMapping("/queryShowCase") - public ShowCaseResp queryShowCase( - @RequestBody PageQueryInfoReq pageQueryInfoCommand, + public ShowCaseResp queryShowCase(@RequestBody PageQueryInfoReq pageQueryInfoCommand, @RequestParam(value = "agentId") int agentId) { return chatService.queryShowCase(pageQueryInfoCommand, agentId); } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatQueryController.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatQueryController.java index e38ef5940..cba4b6887 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatQueryController.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/ChatQueryController.java @@ -27,43 +27,33 @@ import org.springframework.web.bind.annotation.RestController; @RequestMapping({"/api/chat/query", "/openapi/chat/query"}) public class ChatQueryController { - @Autowired private ChatQueryService chatQueryService; + @Autowired + private ChatQueryService chatQueryService; @PostMapping("search") - public Object search( - @RequestBody ChatParseReq chatParseReq, - HttpServletRequest request, + public Object search(@RequestBody ChatParseReq chatParseReq, HttpServletRequest request, HttpServletResponse response) { chatParseReq.setUser(UserHolder.findUser(request, response)); return chatQueryService.search(chatParseReq); } @PostMapping("parse") - public Object parse( - @RequestBody ChatParseReq chatParseReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public Object parse(@RequestBody ChatParseReq chatParseReq, HttpServletRequest request, + HttpServletResponse response) throws Exception { chatParseReq.setUser(UserHolder.findUser(request, response)); return chatQueryService.performParsing(chatParseReq); } @PostMapping("execute") - public Object execute( - @RequestBody ChatExecuteReq chatExecuteReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public Object execute(@RequestBody ChatExecuteReq chatExecuteReq, HttpServletRequest request, + HttpServletResponse response) throws Exception { chatExecuteReq.setUser(UserHolder.findUser(request, response)); return chatQueryService.performExecution(chatExecuteReq); } @PostMapping("/") - public Object query( - @RequestBody ChatParseReq chatParseReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public Object query(@RequestBody ChatParseReq chatParseReq, HttpServletRequest request, + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); chatParseReq.setUser(user); ParseResp parseResp = chatQueryService.performParsing(chatParseReq); @@ -80,22 +70,16 @@ public class ChatQueryController { } @PostMapping("queryData") - public Object queryData( - @RequestBody ChatQueryDataReq chatQueryDataReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public Object queryData(@RequestBody ChatQueryDataReq chatQueryDataReq, + HttpServletRequest request, HttpServletResponse response) throws Exception { chatQueryDataReq.setUser(UserHolder.findUser(request, response)); return chatQueryService.queryData(chatQueryDataReq, UserHolder.findUser(request, response)); } @PostMapping("queryDimensionValue") - public Object queryDimensionValue( - @RequestBody @Valid DimensionValueReq dimensionValueReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { - return chatQueryService.queryDimensionValue( - dimensionValueReq, UserHolder.findUser(request, response)); + public Object queryDimensionValue(@RequestBody @Valid DimensionValueReq dimensionValueReq, + HttpServletRequest request, HttpServletResponse response) throws Exception { + return chatQueryService.queryDimensionValue(dimensionValueReq, + UserHolder.findUser(request, response)); } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/MemoryController.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/MemoryController.java index 2f27543ca..5a720dfc7 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/MemoryController.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/MemoryController.java @@ -21,13 +21,12 @@ import org.springframework.web.bind.annotation.RestController; @RequestMapping({"/api/chat/memory"}) public class MemoryController { - @Autowired private MemoryService memoryService; + @Autowired + private MemoryService memoryService; @PostMapping("/updateMemory") - public Boolean updateMemory( - @RequestBody ChatMemoryUpdateReq chatMemoryUpdateReq, - HttpServletRequest request, - HttpServletResponse response) { + public Boolean updateMemory(@RequestBody ChatMemoryUpdateReq chatMemoryUpdateReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); memoryService.updateMemory(chatMemoryUpdateReq, user); return true; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/PluginController.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/PluginController.java index 9f4bf6390..53a9470f7 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/PluginController.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/rest/PluginController.java @@ -25,23 +25,20 @@ import java.util.List; @RequestMapping("/api/chat/plugin") public class PluginController { - @Autowired protected PluginService pluginService; + @Autowired + protected PluginService pluginService; @PostMapping - public boolean createPlugin( - @RequestBody ChatPlugin plugin, - HttpServletRequest httpServletRequest, - HttpServletResponse httpServletResponse) { + public boolean createPlugin(@RequestBody ChatPlugin plugin, + HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) { User user = UserHolder.findUser(httpServletRequest, httpServletResponse); pluginService.createPlugin(plugin, user); return true; } @PutMapping - public boolean updatePlugin( - @RequestBody ChatPlugin plugin, - HttpServletRequest httpServletRequest, - HttpServletResponse httpServletResponse) { + public boolean updatePlugin(@RequestBody ChatPlugin plugin, + HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) { User user = UserHolder.findUser(httpServletRequest, httpServletResponse); pluginService.updatePlugin(plugin, user); return true; @@ -59,18 +56,16 @@ public class PluginController { } @PostMapping("/query") - List query( - @RequestBody PluginQueryReq pluginQueryReq, - HttpServletRequest httpServletRequest, - HttpServletResponse httpServletResponse) { + List query(@RequestBody PluginQueryReq pluginQueryReq, + HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) { User user = UserHolder.findUser(httpServletRequest, httpServletResponse); return pluginService.queryWithAuthCheck(pluginQueryReq, user); } @AuthenticationIgnore @PostMapping("/pluginDemo") - public String pluginDemo( - @RequestParam("queryText") String queryText, @RequestBody Object object) { + public String pluginDemo(@RequestParam("queryText") String queryText, + @RequestBody Object object) { return String.format("已收到您的问题:%s, 但这只是一个demo~", queryText); } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java index 6719d08c6..2ea94f443 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/AgentServiceImpl.java @@ -33,9 +33,11 @@ import java.util.stream.Collectors; @Service public class AgentServiceImpl extends ServiceImpl implements AgentService { - @Autowired private MemoryService memoryService; + @Autowired + private MemoryService memoryService; - @Autowired private ChatQueryService chatQueryService; + @Autowired + private ChatQueryService chatQueryService; private ExecutorService executorService = Executors.newFixedThreadPool(1); @@ -98,8 +100,7 @@ public class AgentServiceImpl extends ServiceImpl implem } private synchronized void doExecuteAgentExamples(Agent agent) { - if (!agent.containsLLMTool() - || !LLMConnHelper.testConnection(agent.getModelConfig()) + if (!agent.containsLLMTool() || !LLMConnHelper.testConnection(agent.getModelConfig()) || CollectionUtils.isEmpty(agent.getExamples())) { return; } @@ -107,10 +108,8 @@ public class AgentServiceImpl extends ServiceImpl implem List examples = agent.getExamples(); ChatMemoryFilter chatMemoryFilter = ChatMemoryFilter.builder().agentId(agent.getId()).questions(examples).build(); - List memoriesExisted = - memoryService.getMemories(chatMemoryFilter).stream() - .map(ChatMemoryDO::getQuestion) - .collect(Collectors.toList()); + List memoriesExisted = memoryService.getMemories(chatMemoryFilter).stream() + .map(ChatMemoryDO::getQuestion).collect(Collectors.toList()); for (String example : examples) { if (memoriesExisted.contains(example)) { continue; diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatManageServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatManageServiceImpl.java index 77102cd10..a40cd3807 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatManageServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatManageServiceImpl.java @@ -37,8 +37,10 @@ import java.util.stream.Collectors; @Service public class ChatManageServiceImpl implements ChatManageService { - @Autowired private ChatRepository chatRepository; - @Autowired private ChatQueryRepository chatQueryRepository; + @Autowired + private ChatRepository chatRepository; + @Autowired + private ChatQueryRepository chatQueryRepository; @Override public Long addChat(User user, String chatName, Integer agentId) { @@ -121,30 +123,23 @@ public class ChatManageServiceImpl implements ChatManageService { if (CollectionUtils.isEmpty(queryResps)) { return showCaseResp; } - queryResps.removeIf( - queryResp -> { - if (queryResp.getQueryResult() == null) { - return true; - } - if (queryResp.getQueryResult().getResponse() != null) { - return false; - } - if (CollectionUtils.isEmpty(queryResp.getQueryResult().getQueryResults())) { - return true; - } - Map data = queryResp.getQueryResult().getQueryResults().get(0); - return CollectionUtils.isEmpty(data); - }); - queryResps = - new ArrayList<>( - queryResps.stream() - .collect( - Collectors.toMap( - QueryResp::getQueryText, - Function.identity(), - (existing, replacement) -> existing, - LinkedHashMap::new)) - .values()); + queryResps.removeIf(queryResp -> { + if (queryResp.getQueryResult() == null) { + return true; + } + if (queryResp.getQueryResult().getResponse() != null) { + return false; + } + if (CollectionUtils.isEmpty(queryResp.getQueryResult().getQueryResults())) { + return true; + } + Map data = queryResp.getQueryResult().getQueryResults().get(0); + return CollectionUtils.isEmpty(data); + }); + queryResps = new ArrayList<>(queryResps.stream() + .collect(Collectors.toMap(QueryResp::getQueryText, Function.identity(), + (existing, replacement) -> existing, LinkedHashMap::new)) + .values()); fillParseInfo(queryResps); Map> showCaseMap = queryResps.stream().collect(Collectors.groupingBy(QueryResp::getChatId)); @@ -166,17 +161,11 @@ public class ChatManageServiceImpl implements ChatManageService { if (CollectionUtils.isEmpty(chatParseDOList)) { continue; } - List parseInfos = - chatParseDOList.stream() - .map( - chatParseDO -> - JsonUtil.toObject( - chatParseDO.getParseInfo(), - SemanticParseInfo.class)) - .sorted( - Comparator.comparingDouble(SemanticParseInfo::getScore) - .reversed()) - .collect(Collectors.toList()); + List parseInfos = chatParseDOList.stream() + .map(chatParseDO -> JsonUtil.toObject(chatParseDO.getParseInfo(), + SemanticParseInfo.class)) + .sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed()) + .collect(Collectors.toList()); queryResp.setParseInfos(parseInfos); } } @@ -188,10 +177,8 @@ public class ChatManageServiceImpl implements ChatManageService { chatQueryDO.setQueryResult(JsonUtil.toString(queryResult)); chatQueryDO.setQueryState(1); updateQuery(chatQueryDO); - chatRepository.updateLastQuestion( - chatExecuteReq.getChatId().longValue(), - chatExecuteReq.getQueryText(), - getCurrentTime()); + chatRepository.updateLastQuestion(chatExecuteReq.getChatId().longValue(), + chatExecuteReq.getQueryText(), getCurrentTime()); return chatQueryDO; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java index dc4e67747..74488deea 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ChatQueryServiceImpl.java @@ -78,10 +78,14 @@ import java.util.stream.Collectors; @Service public class ChatQueryServiceImpl implements ChatQueryService { - @Autowired private ChatManageService chatManageService; - @Autowired private ChatLayerService chatLayerService; - @Autowired private SemanticLayerService semanticLayerService; - @Autowired private AgentService agentService; + @Autowired + private ChatManageService chatManageService; + @Autowired + private ChatLayerService chatLayerService; + @Autowired + private SemanticLayerService semanticLayerService; + @Autowired + private AgentService agentService; private List chatQueryParsers = ComponentFactory.getChatParsers(); private List chatQueryExecutors = ComponentFactory.getChatExecutors(); @@ -149,11 +153,8 @@ public class ChatQueryServiceImpl implements ChatQueryService { chatParseReq.setUser(User.getFakeUser()); ParseResp parseResp = performParsing(chatParseReq); if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) { - log.debug( - "chatId:{}, agentId:{}, queryText:{}, parseResp.getSelectedParses() is empty", - chatId, - agentId, - queryText); + log.debug("chatId:{}, agentId:{}, queryText:{}, parseResp.getSelectedParses() is empty", + chatId, agentId, queryText); return null; } ChatExecuteReq executeReq = new ChatExecuteReq(); @@ -184,9 +185,8 @@ public class ChatQueryServiceImpl implements ChatQueryService { private ExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) { ExecuteContext executeContext = new ExecuteContext(); BeanMapper.mapper(chatExecuteReq, executeContext); - SemanticParseInfo parseInfo = - chatManageService.getParseInfo( - chatExecuteReq.getQueryId(), chatExecuteReq.getParseId()); + SemanticParseInfo parseInfo = chatManageService.getParseInfo(chatExecuteReq.getQueryId(), + chatExecuteReq.getParseId()); Agent agent = agentService.getAgent(chatExecuteReq.getAgentId()); executeContext.setAgent(agent); executeContext.setParseInfo(parseInfo); @@ -222,12 +222,8 @@ public class ChatQueryServiceImpl implements ChatQueryService { return SqlSelectHelper.getAllSelectFields(sqlInfo.getCorrectedS2SQL()); } - private void handleLLMQueryMode( - ChatQueryDataReq chatQueryDataReq, - SemanticQuery semanticQuery, - DataSetSchema dataSetSchema, - User user) - throws Exception { + private void handleLLMQueryMode(ChatQueryDataReq chatQueryDataReq, SemanticQuery semanticQuery, + DataSetSchema dataSetSchema, User user) throws Exception { SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); List fields = getFieldsFromSql(parseInfo); if (checkMetricReplace(fields, chatQueryDataReq.getMetrics())) { @@ -245,16 +241,16 @@ public class ChatQueryServiceImpl implements ChatQueryService { } } - private void handleRuleQueryMode( - SemanticQuery semanticQuery, DataSetSchema dataSetSchema, User user) { + private void handleRuleQueryMode(SemanticQuery semanticQuery, DataSetSchema dataSetSchema, + User user) { log.info("rule begin replace metrics and revise filters!"); validFilter(semanticQuery.getParseInfo().getDimensionFilters()); validFilter(semanticQuery.getParseInfo().getMetricFilters()); semanticQuery.initS2Sql(dataSetSchema, user); } - private QueryResult executeQuery( - SemanticQuery semanticQuery, User user, DataSetSchema dataSetSchema) throws Exception { + private QueryResult executeQuery(SemanticQuery semanticQuery, User user, + DataSetSchema dataSetSchema) throws Exception { SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); QueryResult queryResult = doExecution(semanticQueryReq, parseInfo.getQueryMode(), user); @@ -275,8 +271,8 @@ public class ChatQueryServiceImpl implements ChatQueryService { return !oriFields.containsAll(metricNames); } - private String reviseCorrectS2SQL( - ChatQueryDataReq queryData, SemanticParseInfo parseInfo, DataSetSchema dataSetSchema) { + private String reviseCorrectS2SQL(ChatQueryDataReq queryData, SemanticParseInfo parseInfo, + DataSetSchema dataSetSchema) { String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL(); log.info("correctorSql before replacing:{}", correctorSql); // get where filter and having filter @@ -286,21 +282,12 @@ public class ChatQueryServiceImpl implements ChatQueryService { // replace where filter List addWhereConditions = new ArrayList<>(); Set removeWhereFieldNames = - updateFilters( - whereExpressionList, - queryData.getDimensionFilters(), - parseInfo.getDimensionFilters(), - addWhereConditions); + updateFilters(whereExpressionList, queryData.getDimensionFilters(), + parseInfo.getDimensionFilters(), addWhereConditions); Map> filedNameToValueMap = new HashMap<>(); - Set removeDataFieldNames = - updateDateInfo( - queryData, - parseInfo, - dataSetSchema, - filedNameToValueMap, - whereExpressionList, - addWhereConditions); + Set removeDataFieldNames = updateDateInfo(queryData, parseInfo, dataSetSchema, + filedNameToValueMap, whereExpressionList, addWhereConditions); removeWhereFieldNames.addAll(removeDataFieldNames); correctorSql = SqlReplaceHelper.replaceValue(correctorSql, filedNameToValueMap); @@ -311,11 +298,8 @@ public class ChatQueryServiceImpl implements ChatQueryService { SqlSelectHelper.getHavingExpressions(correctorSql); List addHavingConditions = new ArrayList<>(); Set removeHavingFieldNames = - updateFilters( - havingExpressionList, - queryData.getDimensionFilters(), - parseInfo.getDimensionFilters(), - addHavingConditions); + updateFilters(havingExpressionList, queryData.getDimensionFilters(), + parseInfo.getDimensionFilters(), addHavingConditions); correctorSql = SqlReplaceHelper.replaceHavingValue(correctorSql, new HashMap<>()); correctorSql = SqlRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames); @@ -326,10 +310,8 @@ public class ChatQueryServiceImpl implements ChatQueryService { } private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) { - List oriMetrics = - parseInfo.getMetrics().stream() - .map(SchemaElement::getName) - .collect(Collectors.toList()); + List oriMetrics = parseInfo.getMetrics().stream().map(SchemaElement::getName) + .collect(Collectors.toList()); String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL(); log.info("before replaceMetrics:{}", correctorSql); log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric); @@ -362,20 +344,15 @@ public class ChatQueryServiceImpl implements ChatQueryService { return queryResult; } - private Set updateDateInfo( - ChatQueryDataReq queryData, - SemanticParseInfo parseInfo, - DataSetSchema dataSetSchema, - Map> filedNameToValueMap, - List fieldExpressionList, - List addConditions) { + private Set updateDateInfo(ChatQueryDataReq queryData, SemanticParseInfo parseInfo, + DataSetSchema dataSetSchema, Map> filedNameToValueMap, + List fieldExpressionList, List addConditions) { Set removeFieldNames = new HashSet<>(); if (Objects.isNull(queryData.getDateInfo())) { return removeFieldNames; } if (queryData.getDateInfo().getUnit() > 1) { - queryData - .getDateInfo() + queryData.getDateInfo() .setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1)); queryData.getDateInfo().setEndDate(DateUtils.getBeforeDate(0)); } @@ -386,16 +363,10 @@ public class ChatQueryServiceImpl implements ChatQueryService { // first remove,then add removeFieldNames.add(partitionDimension.getName()); GreaterThanEquals greaterThanEquals = new GreaterThanEquals(); - addTimeFilters( - queryData.getDateInfo().getStartDate(), - greaterThanEquals, - addConditions, - partitionDimension); + addTimeFilters(queryData.getDateInfo().getStartDate(), greaterThanEquals, + addConditions, partitionDimension); MinorThanEquals minorThanEquals = new MinorThanEquals(); - addTimeFilters( - queryData.getDateInfo().getEndDate(), - minorThanEquals, - addConditions, + addTimeFilters(queryData.getDateInfo().getEndDate(), minorThanEquals, addConditions, partitionDimension); break; } @@ -403,8 +374,7 @@ public class ChatQueryServiceImpl implements ChatQueryService { for (FieldExpression fieldExpression : fieldExpressionList) { for (QueryFilter queryFilter : queryData.getDimensionFilters()) { if (queryFilter.getOperator().equals(FilterOperatorEnum.LIKE) - && FilterOperatorEnum.LIKE - .getValue() + && FilterOperatorEnum.LIKE.getValue() .equalsIgnoreCase(fieldExpression.getOperator())) { Map replaceMap = new HashMap<>(); String preValue = fieldExpression.getFieldValue().toString(); @@ -425,11 +395,8 @@ public class ChatQueryServiceImpl implements ChatQueryService { return removeFieldNames; } - private void addTimeFilters( - String date, - T comparisonExpression, - List addConditions, - SchemaElement partitionDimension) { + private void addTimeFilters(String date, T comparisonExpression, + List addConditions, SchemaElement partitionDimension) { Column column = new Column(partitionDimension.getName()); StringValue stringValue = new StringValue(date); comparisonExpression.setLeftExpression(column); @@ -437,10 +404,8 @@ public class ChatQueryServiceImpl implements ChatQueryService { addConditions.add(comparisonExpression); } - private Set updateFilters( - List fieldExpressionList, - Set metricFilters, - Set contextMetricFilters, + private Set updateFilters(List fieldExpressionList, + Set metricFilters, Set contextMetricFilters, List addConditions) { Set removeFieldNames = new HashSet<>(); if (CollectionUtils.isEmpty(metricFilters)) { @@ -460,15 +425,13 @@ public class ChatQueryServiceImpl implements ChatQueryService { return removeFieldNames; } - private void handleFilter( - QueryFilter dslQueryFilter, - Set contextMetricFilters, + private void handleFilter(QueryFilter dslQueryFilter, Set contextMetricFilters, List addConditions) { FilterOperatorEnum operator = dslQueryFilter.getOperator(); if (operator == FilterOperatorEnum.IN) { - addWhereInFilters( - dslQueryFilter, new InExpression(), contextMetricFilters, addConditions); + addWhereInFilters(dslQueryFilter, new InExpression(), contextMetricFilters, + addConditions); } else { ComparisonOperator expression = FilterOperatorEnum.createExpression(operator); if (Objects.nonNull(expression)) { @@ -477,12 +440,9 @@ public class ChatQueryServiceImpl implements ChatQueryService { } } - // add in condition to sql where condition - private void addWhereInFilters( - QueryFilter dslQueryFilter, - InExpression inExpression, - Set contextMetricFilters, - List addConditions) { + // add in condition to sql where condition + private void addWhereInFilters(QueryFilter dslQueryFilter, InExpression inExpression, + Set contextMetricFilters, List addConditions) { Column column = new Column(dslQueryFilter.getName()); ParenthesedExpressionList parenthesedExpressionList = new ParenthesedExpressionList<>(); List valueList = @@ -490,30 +450,24 @@ public class ChatQueryServiceImpl implements ChatQueryService { if (CollectionUtils.isEmpty(valueList)) { return; } - valueList.stream() - .forEach( - o -> { - StringValue stringValue = new StringValue(o); - parenthesedExpressionList.add(stringValue); - }); + valueList.stream().forEach(o -> { + StringValue stringValue = new StringValue(o); + parenthesedExpressionList.add(stringValue); + }); inExpression.setLeftExpression(column); inExpression.setRightExpression(parenthesedExpressionList); addConditions.add(inExpression); - contextMetricFilters.stream() - .forEach( - o -> { - if (o.getName().equals(dslQueryFilter.getName())) { - o.setValue(dslQueryFilter.getValue()); - o.setOperator(dslQueryFilter.getOperator()); - } - }); + contextMetricFilters.stream().forEach(o -> { + if (o.getName().equals(dslQueryFilter.getName())) { + o.setValue(dslQueryFilter.getValue()); + o.setOperator(dslQueryFilter.getOperator()); + } + }); } // add where filter - private void addWhereFilters( - QueryFilter dslQueryFilter, - ComparisonOperator comparisonExpression, - Set contextMetricFilters, + private void addWhereFilters(QueryFilter dslQueryFilter, + ComparisonOperator comparisonExpression, Set contextMetricFilters, List addConditions) { String columnName = dslQueryFilter.getName(); if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) { @@ -533,18 +487,16 @@ public class ChatQueryServiceImpl implements ChatQueryService { comparisonExpression.setRightExpression(stringValue); } addConditions.add(comparisonExpression); - contextMetricFilters.stream() - .forEach( - o -> { - if (o.getName().equals(dslQueryFilter.getName())) { - o.setValue(dslQueryFilter.getValue()); - o.setOperator(dslQueryFilter.getOperator()); - } - }); + contextMetricFilters.stream().forEach(o -> { + if (o.getName().equals(dslQueryFilter.getName())) { + o.setValue(dslQueryFilter.getValue()); + o.setOperator(dslQueryFilter.getOperator()); + } + }); } - private SemanticParseInfo mergeParseInfo( - SemanticParseInfo parseInfo, ChatQueryDataReq queryData) { + private SemanticParseInfo mergeParseInfo(SemanticParseInfo parseInfo, + ChatQueryDataReq queryData) { if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) { return parseInfo; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ConfigServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ConfigServiceImpl.java index 0fd2ed66d..7ee668268 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ConfigServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/ConfigServiceImpl.java @@ -51,10 +51,8 @@ public class ConfigServiceImpl implements ConfigService { private final ChatConfigHelper chatConfigHelper; private final SemanticLayerService semanticLayerService; - public ConfigServiceImpl( - ChatConfigRepository chatConfigRepository, - ChatConfigHelper chatConfigHelper, - SemanticLayerService semanticLayerService) { + public ConfigServiceImpl(ChatConfigRepository chatConfigRepository, + ChatConfigHelper chatConfigHelper, SemanticLayerService semanticLayerService) { this.chatConfigRepository = chatConfigRepository; this.chatConfigHelper = chatConfigHelper; this.semanticLayerService = semanticLayerService; @@ -80,9 +78,8 @@ public class ConfigServiceImpl implements ConfigService { @Override public Long editConfig(ChatConfigEditReqReq configEditCmd, User user) { log.info("[edit model extend] object:{}", JsonUtil.toString(configEditCmd, true)); - if (Objects.isNull(configEditCmd) - || Objects.isNull(configEditCmd.getId()) - && Objects.isNull(configEditCmd.getModelId())) { + if (Objects.isNull(configEditCmd) || Objects.isNull(configEditCmd.getId()) + && Objects.isNull(configEditCmd.getModelId())) { throw new RuntimeException( "editConfig, id and modelId are not allowed to be empty at the same time"); } @@ -107,13 +104,13 @@ public class ConfigServiceImpl implements ConfigService { List blackDimIdList = new ArrayList<>(); if (Objects.nonNull(chatConfig.getChatAggConfig()) && Objects.nonNull(chatConfig.getChatAggConfig().getVisibility())) { - blackDimIdList.addAll( - chatConfig.getChatAggConfig().getVisibility().getBlackDimIdList()); + blackDimIdList + .addAll(chatConfig.getChatAggConfig().getVisibility().getBlackDimIdList()); } if (Objects.nonNull(chatConfig.getChatDetailConfig()) && Objects.nonNull(chatConfig.getChatDetailConfig().getVisibility())) { - blackDimIdList.addAll( - chatConfig.getChatDetailConfig().getVisibility().getBlackDimIdList()); + blackDimIdList + .addAll(chatConfig.getChatDetailConfig().getVisibility().getBlackDimIdList()); } List filterDimIdList = blackDimIdList.stream().distinct().collect(Collectors.toList()); @@ -121,8 +118,8 @@ public class ConfigServiceImpl implements ConfigService { List blackMetricIdList = new ArrayList<>(); if (Objects.nonNull(chatConfig.getChatAggConfig()) && Objects.nonNull(chatConfig.getChatAggConfig().getVisibility())) { - blackMetricIdList.addAll( - chatConfig.getChatAggConfig().getVisibility().getBlackMetricIdList()); + blackMetricIdList + .addAll(chatConfig.getChatAggConfig().getVisibility().getBlackMetricIdList()); } if (Objects.nonNull(chatConfig.getChatDetailConfig()) && Objects.nonNull(chatConfig.getChatDetailConfig().getVisibility())) { @@ -138,20 +135,16 @@ public class ConfigServiceImpl implements ConfigService { if (!CollectionUtils.isEmpty(blackDimIdList)) { List dimensionRespList = semanticLayerService.getDimensions(metaFilter); List blackDimNameList = - dimensionRespList.stream() - .filter(o -> filterDimIdList.contains(o.getId())) - .map(SchemaItem::getName) - .collect(Collectors.toList()); + dimensionRespList.stream().filter(o -> filterDimIdList.contains(o.getId())) + .map(SchemaItem::getName).collect(Collectors.toList()); itemNameVisibility.setBlackDimNameList(blackDimNameList); } if (!CollectionUtils.isEmpty(blackMetricIdList)) { List metricRespList = semanticLayerService.getMetrics(metaFilter); List blackMetricList = - metricRespList.stream() - .filter(o -> filterMetricIdList.contains(o.getId())) - .map(SchemaItem::getName) - .collect(Collectors.toList()); + metricRespList.stream().filter(o -> filterMetricIdList.contains(o.getId())) + .map(SchemaItem::getName).collect(Collectors.toList()); itemNameVisibility.setBlackMetricNameList(blackMetricList); } return itemNameVisibility; @@ -169,8 +162,8 @@ public class ConfigServiceImpl implements ConfigService { return chatConfigRepository.getConfigByModelId(modelId); } - private ItemVisibilityInfo fetchVisibilityDescByConfig( - ItemVisibility visibility, DataSetSchema modelSchema) { + private ItemVisibilityInfo fetchVisibilityDescByConfig(ItemVisibility visibility, + DataSetSchema modelSchema) { ItemVisibilityInfo itemVisibilityDesc = new ItemVisibilityInfo(); List dimIdAllList = chatConfigHelper.generateAllDimIdList(modelSchema); @@ -186,17 +179,12 @@ public class ConfigServiceImpl implements ConfigService { blackMetricIdList.addAll(visibility.getBlackMetricIdList()); } } - List whiteMetricIdList = - metricIdAllList.stream() - .filter( - id -> - !blackMetricIdList.contains(id) - && metricIdAllList.contains(id)) - .collect(Collectors.toList()); - List whiteDimIdList = - dimIdAllList.stream() - .filter(id -> !blackDimIdList.contains(id) && dimIdAllList.contains(id)) - .collect(Collectors.toList()); + List whiteMetricIdList = metricIdAllList.stream() + .filter(id -> !blackMetricIdList.contains(id) && metricIdAllList.contains(id)) + .collect(Collectors.toList()); + List whiteDimIdList = dimIdAllList.stream() + .filter(id -> !blackDimIdList.contains(id) && dimIdAllList.contains(id)) + .collect(Collectors.toList()); itemVisibilityDesc.setBlackDimIdList(blackDimIdList); itemVisibilityDesc.setBlackMetricIdList(blackMetricIdList); @@ -232,10 +220,8 @@ public class ConfigServiceImpl implements ConfigService { return chatConfigRich; } - private ChatDetailRichConfigResp fillChatDetailRichConfig( - DataSetSchema modelSchema, - ChatConfigRichResp chatConfigRich, - ChatConfigResp chatConfigResp) { + private ChatDetailRichConfigResp fillChatDetailRichConfig(DataSetSchema modelSchema, + ChatConfigRichResp chatConfigRich, ChatConfigResp chatConfigResp) { if (Objects.isNull(chatConfigResp) || Objects.isNull(chatConfigResp.getChatDetailConfig())) { return null; @@ -248,9 +234,8 @@ public class ConfigServiceImpl implements ConfigService { detailRichConfig.setKnowledgeInfos( fillKnowledgeBizName(chatDetailConfig.getKnowledgeInfos(), modelSchema)); detailRichConfig.setGlobalKnowledgeConfig(chatDetailConfig.getGlobalKnowledgeConfig()); - detailRichConfig.setChatDefaultConfig( - fetchDefaultConfig( - chatDetailConfig.getChatDefaultConfig(), modelSchema, itemVisibilityInfo)); + detailRichConfig.setChatDefaultConfig(fetchDefaultConfig( + chatDetailConfig.getChatDefaultConfig(), modelSchema, itemVisibilityInfo)); return detailRichConfig; } @@ -261,18 +246,15 @@ public class ConfigServiceImpl implements ConfigService { return entityRichInfo; } BeanUtils.copyProperties(entity, entityRichInfo); - Map dimIdAndRespPair = - modelSchema.getDimensions().stream() - .collect( - Collectors.toMap( - SchemaElement::getId, Function.identity(), (k1, k2) -> k1)); + Map dimIdAndRespPair = modelSchema.getDimensions().stream().collect( + Collectors.toMap(SchemaElement::getId, Function.identity(), (k1, k2) -> k1)); entityRichInfo.setDimItem(dimIdAndRespPair.get(entity.getEntityId())); return entityRichInfo; } - private ChatAggRichConfigResp fillChatAggRichConfig( - DataSetSchema modelSchema, ChatConfigResp chatConfigResp) { + private ChatAggRichConfigResp fillChatAggRichConfig(DataSetSchema modelSchema, + ChatConfigResp chatConfigResp) { if (Objects.isNull(chatConfigResp) || Objects.isNull(chatConfigResp.getChatAggConfig())) { return null; } @@ -284,72 +266,53 @@ public class ConfigServiceImpl implements ConfigService { chatAggRichConfig.setKnowledgeInfos( fillKnowledgeBizName(chatAggConfig.getKnowledgeInfos(), modelSchema)); chatAggRichConfig.setGlobalKnowledgeConfig(chatAggConfig.getGlobalKnowledgeConfig()); - chatAggRichConfig.setChatDefaultConfig( - fetchDefaultConfig( - chatAggConfig.getChatDefaultConfig(), modelSchema, itemVisibilityInfo)); + chatAggRichConfig.setChatDefaultConfig(fetchDefaultConfig( + chatAggConfig.getChatDefaultConfig(), modelSchema, itemVisibilityInfo)); return chatAggRichConfig; } - private ChatDefaultRichConfigResp fetchDefaultConfig( - ChatDefaultConfigReq chatDefaultConfig, - DataSetSchema modelSchema, - ItemVisibilityInfo itemVisibilityInfo) { + private ChatDefaultRichConfigResp fetchDefaultConfig(ChatDefaultConfigReq chatDefaultConfig, + DataSetSchema modelSchema, ItemVisibilityInfo itemVisibilityInfo) { ChatDefaultRichConfigResp defaultRichConfig = new ChatDefaultRichConfigResp(); if (Objects.isNull(chatDefaultConfig)) { return defaultRichConfig; } BeanUtils.copyProperties(chatDefaultConfig, defaultRichConfig); - Map dimIdAndRespPair = - modelSchema.getDimensions().stream() - .collect( - Collectors.toMap( - SchemaElement::getId, Function.identity(), (k1, k2) -> k1)); + Map dimIdAndRespPair = modelSchema.getDimensions().stream().collect( + Collectors.toMap(SchemaElement::getId, Function.identity(), (k1, k2) -> k1)); - Map metricIdAndRespPair = - modelSchema.getMetrics().stream() - .collect( - Collectors.toMap( - SchemaElement::getId, Function.identity(), (k1, k2) -> k1)); + Map metricIdAndRespPair = modelSchema.getMetrics().stream().collect( + Collectors.toMap(SchemaElement::getId, Function.identity(), (k1, k2) -> k1)); List dimensions = new ArrayList<>(); List metrics = new ArrayList<>(); if (!CollectionUtils.isEmpty(chatDefaultConfig.getDimensionIds())) { chatDefaultConfig.getDimensionIds().stream() - .filter( - dimId -> - dimIdAndRespPair.containsKey(dimId) - && itemVisibilityInfo - .getWhiteDimIdList() - .contains(dimId)) - .forEach( - dimId -> { - SchemaElement dimSchemaResp = dimIdAndRespPair.get(dimId); - if (Objects.nonNull(dimSchemaResp)) { - SchemaElement dimSchema = new SchemaElement(); - BeanUtils.copyProperties(dimSchemaResp, dimSchema); - dimensions.add(dimSchema); - } - }); + .filter(dimId -> dimIdAndRespPair.containsKey(dimId) + && itemVisibilityInfo.getWhiteDimIdList().contains(dimId)) + .forEach(dimId -> { + SchemaElement dimSchemaResp = dimIdAndRespPair.get(dimId); + if (Objects.nonNull(dimSchemaResp)) { + SchemaElement dimSchema = new SchemaElement(); + BeanUtils.copyProperties(dimSchemaResp, dimSchema); + dimensions.add(dimSchema); + } + }); } if (!CollectionUtils.isEmpty(chatDefaultConfig.getMetricIds())) { chatDefaultConfig.getMetricIds().stream() - .filter( - metricId -> - metricIdAndRespPair.containsKey(metricId) - && itemVisibilityInfo - .getWhiteMetricIdList() - .contains(metricId)) - .forEach( - metricId -> { - SchemaElement metricSchemaResp = metricIdAndRespPair.get(metricId); - if (Objects.nonNull(metricSchemaResp)) { - SchemaElement metricSchema = new SchemaElement(); - BeanUtils.copyProperties(metricSchemaResp, metricSchema); - metrics.add(metricSchema); - } - }); + .filter(metricId -> metricIdAndRespPair.containsKey(metricId) + && itemVisibilityInfo.getWhiteMetricIdList().contains(metricId)) + .forEach(metricId -> { + SchemaElement metricSchemaResp = metricIdAndRespPair.get(metricId); + if (Objects.nonNull(metricSchemaResp)) { + SchemaElement metricSchema = new SchemaElement(); + BeanUtils.copyProperties(metricSchemaResp, metricSchema); + metrics.add(metricSchema); + } + }); } defaultRichConfig.setDimensions(dimensions); @@ -357,27 +320,21 @@ public class ConfigServiceImpl implements ConfigService { return defaultRichConfig; } - private List fillKnowledgeBizName( - List knowledgeInfos, DataSetSchema modelSchema) { + private List fillKnowledgeBizName(List knowledgeInfos, + DataSetSchema modelSchema) { if (CollectionUtils.isEmpty(knowledgeInfos)) { return new ArrayList<>(); } - Map dimIdAndRespPair = - modelSchema.getDimensions().stream() - .collect( - Collectors.toMap( - SchemaElement::getId, Function.identity(), (k1, k2) -> k1)); - knowledgeInfos.stream() - .forEach( - knowledgeInfo -> { - if (Objects.nonNull(knowledgeInfo)) { - SchemaElement dimSchemaResp = - dimIdAndRespPair.get(knowledgeInfo.getItemId()); - if (Objects.nonNull(dimSchemaResp)) { - knowledgeInfo.setBizName(dimSchemaResp.getBizName()); - } - } - }); + Map dimIdAndRespPair = modelSchema.getDimensions().stream().collect( + Collectors.toMap(SchemaElement::getId, Function.identity(), (k1, k2) -> k1)); + knowledgeInfos.stream().forEach(knowledgeInfo -> { + if (Objects.nonNull(knowledgeInfo)) { + SchemaElement dimSchemaResp = dimIdAndRespPair.get(knowledgeInfo.getItemId()); + if (Objects.nonNull(dimSchemaResp)) { + knowledgeInfo.setBizName(dimSchemaResp.getBizName()); + } + } + }); return knowledgeInfos; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java index 1913f1486..d3fa92a8c 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/MemoryServiceImpl.java @@ -25,11 +25,14 @@ import java.util.List; @Service public class MemoryServiceImpl implements MemoryService { - @Autowired private ChatMemoryRepository chatMemoryRepository; + @Autowired + private ChatMemoryRepository chatMemoryRepository; - @Autowired private ExemplarService exemplarService; + @Autowired + private ExemplarService exemplarService; - @Autowired private EmbeddingConfig embeddingConfig; + @Autowired + private EmbeddingConfig embeddingConfig; @Override public void createMemory(ChatMemoryDO memory) { @@ -85,20 +88,18 @@ public class MemoryServiceImpl implements MemoryService { queryWrapper.lambda().eq(ChatMemoryDO::getStatus, chatMemoryFilter.getStatus()); } if (chatMemoryFilter.getHumanReviewRet() != null) { - queryWrapper - .lambda() - .eq(ChatMemoryDO::getHumanReviewRet, chatMemoryFilter.getHumanReviewRet()); + queryWrapper.lambda().eq(ChatMemoryDO::getHumanReviewRet, + chatMemoryFilter.getHumanReviewRet()); } if (chatMemoryFilter.getLlmReviewRet() != null) { - queryWrapper - .lambda() - .eq(ChatMemoryDO::getLlmReviewRet, chatMemoryFilter.getLlmReviewRet()); + queryWrapper.lambda().eq(ChatMemoryDO::getLlmReviewRet, + chatMemoryFilter.getLlmReviewRet()); } if (StringUtils.isBlank(chatMemoryFilter.getOrderCondition())) { queryWrapper.orderByDesc("id"); } else { - queryWrapper.orderBy( - true, chatMemoryFilter.isAsc(), chatMemoryFilter.getOrderCondition()); + queryWrapper.orderBy(true, chatMemoryFilter.isAsc(), + chatMemoryFilter.getOrderCondition()); } return chatMemoryRepository.getMemories(queryWrapper); } @@ -106,9 +107,7 @@ public class MemoryServiceImpl implements MemoryService { @Override public List getMemoriesForLlmReview() { QueryWrapper queryWrapper = new QueryWrapper<>(); - queryWrapper - .lambda() - .eq(ChatMemoryDO::getStatus, MemoryStatus.PENDING) + queryWrapper.lambda().eq(ChatMemoryDO::getStatus, MemoryStatus.PENDING) .isNull(ChatMemoryDO::getLlmReviewRet); return chatMemoryRepository.getMemories(queryWrapper); } @@ -116,26 +115,18 @@ public class MemoryServiceImpl implements MemoryService { @Override public void enableMemory(ChatMemoryDO memory) { memory.setStatus(MemoryStatus.ENABLED); - exemplarService.storeExemplar( - embeddingConfig.getMemoryCollectionName(memory.getAgentId()), - Text2SQLExemplar.builder() - .question(memory.getQuestion()) - .sideInfo(memory.getSideInfo()) - .dbSchema(memory.getDbSchema()) - .sql(memory.getS2sql()) - .build()); + exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()), + Text2SQLExemplar.builder().question(memory.getQuestion()) + .sideInfo(memory.getSideInfo()).dbSchema(memory.getDbSchema()) + .sql(memory.getS2sql()).build()); } @Override public void disableMemory(ChatMemoryDO memory) { memory.setStatus(MemoryStatus.DISABLED); - exemplarService.removeExemplar( - embeddingConfig.getMemoryCollectionName(memory.getAgentId()), - Text2SQLExemplar.builder() - .question(memory.getQuestion()) - .sideInfo(memory.getSideInfo()) - .dbSchema(memory.getDbSchema()) - .sql(memory.getS2sql()) - .build()); + exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()), + Text2SQLExemplar.builder().question(memory.getQuestion()) + .sideInfo(memory.getSideInfo()).dbSchema(memory.getDbSchema()) + .sql(memory.getS2sql()).build()); } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/PluginServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/PluginServiceImpl.java index d642ec034..fb3f900c6 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/PluginServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/PluginServiceImpl.java @@ -36,8 +36,8 @@ public class PluginServiceImpl implements PluginService { private ApplicationEventPublisher publisher; - public PluginServiceImpl( - PluginRepository pluginRepository, ApplicationEventPublisher publisher) { + public PluginServiceImpl(PluginRepository pluginRepository, + ApplicationEventPublisher publisher) { this.pluginRepository = pluginRepository; this.publisher = publisher; } @@ -110,18 +110,11 @@ public class PluginServiceImpl implements PluginService { } List pluginDOS = pluginRepository.query(queryWrapper); if (StringUtils.isNotBlank(pluginQueryReq.getPattern())) { - pluginDOS = - pluginDOS.stream() - .filter( - pluginDO -> - pluginDO.getPattern() - .contains(pluginQueryReq.getPattern()) - || (pluginDO.getName() != null - && pluginDO.getName() - .contains( - pluginQueryReq - .getPattern()))) - .collect(Collectors.toList()); + pluginDOS = pluginDOS.stream() + .filter(pluginDO -> pluginDO.getPattern().contains(pluginQueryReq.getPattern()) + || (pluginDO.getName() != null + && pluginDO.getName().contains(pluginQueryReq.getPattern()))) + .collect(Collectors.toList()); } return convertList(pluginDOS); } @@ -129,16 +122,13 @@ public class PluginServiceImpl implements PluginService { @Override public Optional getPluginByName(String name) { log.info("name:{}", name); - return getPluginList().stream() - .filter( - plugin -> { - PluginParseConfig functionCallConfig = getPluginParseConfig(plugin); - if (functionCallConfig == null) { - return false; - } - return functionCallConfig.getName().equalsIgnoreCase(name); - }) - .findFirst(); + return getPluginList().stream().filter(plugin -> { + PluginParseConfig functionCallConfig = getPluginParseConfig(plugin); + if (functionCallConfig == null) { + return false; + } + return functionCallConfig.getName().equalsIgnoreCase(name); + }).findFirst(); } private PluginParseConfig getPluginParseConfig(ChatPlugin plugin) { @@ -166,26 +156,17 @@ public class PluginServiceImpl implements PluginService { public Map getNameToPlugin() { List pluginList = getPluginList(); - return pluginList.stream() - .filter( - plugin -> { - PluginParseConfig functionCallConfig = getPluginParseConfig(plugin); - if (functionCallConfig == null) { - return false; - } - return true; - }) - .collect( - Collectors.toMap( - a -> { - PluginParseConfig functionCallConfig = - JsonUtil.toObject( - a.getParseModeConfig(), - PluginParseConfig.class); - return functionCallConfig.getName(); - }, - a -> a, - (k1, k2) -> k1)); + return pluginList.stream().filter(plugin -> { + PluginParseConfig functionCallConfig = getPluginParseConfig(plugin); + if (functionCallConfig == null) { + return false; + } + return true; + }).collect(Collectors.toMap(a -> { + PluginParseConfig functionCallConfig = + JsonUtil.toObject(a.getParseModeConfig(), PluginParseConfig.class); + return functionCallConfig.getName(); + }, a -> a, (k1, k2) -> k1)); } // todo @@ -197,10 +178,8 @@ public class PluginServiceImpl implements PluginService { ChatPlugin plugin = new ChatPlugin(); BeanUtils.copyProperties(pluginDO, plugin); if (StringUtils.isNotBlank(pluginDO.getDataSet())) { - plugin.setDataSetList( - Arrays.stream(pluginDO.getDataSet().split(",")) - .map(Long::parseLong) - .collect(Collectors.toList())); + plugin.setDataSetList(Arrays.stream(pluginDO.getDataSet().split(",")) + .map(Long::parseLong).collect(Collectors.toList())); } return plugin; } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/StatisticsServiceImpl.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/StatisticsServiceImpl.java index 238020a01..2fec11e06 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/StatisticsServiceImpl.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/service/impl/StatisticsServiceImpl.java @@ -14,7 +14,8 @@ import java.util.List; @Slf4j public class StatisticsServiceImpl implements StatisticsService { - @Autowired private StatisticsMapper statisticsMapper; + @Autowired + private StatisticsMapper statisticsMapper; @Async @Override diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ChatConfigHelper.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ChatConfigHelper.java index 1e0024717..72cdd4938 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ChatConfigHelper.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ChatConfigHelper.java @@ -37,10 +37,8 @@ public class ChatConfigHelper { ChatConfig chatConfig = new ChatConfig(); BeanUtils.copyProperties(extendBaseCmd, chatConfig); RecordInfo recordInfo = new RecordInfo(); - String creator = - (Objects.isNull(user) || StringUtils.isEmpty(user.getName())) - ? ADMIN_LOWER - : user.getName(); + String creator = (Objects.isNull(user) || StringUtils.isEmpty(user.getName())) ? ADMIN_LOWER + : user.getName(); recordInfo.createdBy(creator); chatConfig.setRecordInfo(recordInfo); chatConfig.setStatus(StatusEnum.ONLINE); @@ -52,10 +50,9 @@ public class ChatConfigHelper { BeanUtils.copyProperties(extendEditCmd, chatConfig); RecordInfo recordInfo = new RecordInfo(); - String user = - (Objects.isNull(facadeUser) || StringUtils.isEmpty(facadeUser.getName())) - ? ADMIN_LOWER - : facadeUser.getName(); + String user = (Objects.isNull(facadeUser) || StringUtils.isEmpty(facadeUser.getName())) + ? ADMIN_LOWER + : facadeUser.getName(); recordInfo.updatedBy(user); chatConfig.setRecordInfo(recordInfo); return chatConfig; @@ -65,9 +62,8 @@ public class ChatConfigHelper { if (Objects.isNull(modelSchema) || CollectionUtils.isEmpty(modelSchema.getDimensions())) { return new ArrayList<>(); } - Map> dimIdAndDescPair = - modelSchema.getDimensions().stream() - .collect(Collectors.groupingBy(SchemaElement::getId)); + Map> dimIdAndDescPair = modelSchema.getDimensions().stream() + .collect(Collectors.groupingBy(SchemaElement::getId)); return new ArrayList<>(dimIdAndDescPair.keySet()); } @@ -75,9 +71,8 @@ public class ChatConfigHelper { if (Objects.isNull(modelSchema) || CollectionUtils.isEmpty(modelSchema.getMetrics())) { return new ArrayList<>(); } - Map> metricIdAndDescPair = - modelSchema.getMetrics().stream() - .collect(Collectors.groupingBy(SchemaElement::getId)); + Map> metricIdAndDescPair = modelSchema.getMetrics().stream() + .collect(Collectors.groupingBy(SchemaElement::getId)); return new ArrayList<>(metricIdAndDescPair.keySet()); } @@ -87,8 +82,8 @@ public class ChatConfigHelper { chatConfigDO.setChatAggConfig(JsonUtil.toString(chatConfig.getChatAggConfig())); chatConfigDO.setChatDetailConfig(JsonUtil.toString(chatConfig.getChatDetailConfig())); - chatConfigDO.setRecommendedQuestions( - JsonUtil.toString(chatConfig.getRecommendedQuestions())); + chatConfigDO + .setRecommendedQuestions(JsonUtil.toString(chatConfig.getRecommendedQuestions())); if (Objects.isNull(chatConfig.getStatus())) { chatConfigDO.setStatus(null); @@ -118,9 +113,8 @@ public class ChatConfigHelper { JsonUtil.toObject(chatConfigDO.getChatDetailConfig(), ChatDetailConfigReq.class)); chatConfigDescriptor.setChatAggConfig( JsonUtil.toObject(chatConfigDO.getChatAggConfig(), ChatAggConfigReq.class)); - chatConfigDescriptor.setRecommendedQuestions( - JsonUtil.toList( - chatConfigDO.getRecommendedQuestions(), RecommendedQuestionReq.class)); + chatConfigDescriptor.setRecommendedQuestions(JsonUtil + .toList(chatConfigDO.getRecommendedQuestions(), RecommendedQuestionReq.class)); chatConfigDescriptor.setStatusEnum(StatusEnum.of(chatConfigDO.getStatus())); chatConfigDescriptor.setCreatedBy(chatConfigDO.getCreatedBy()); diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ComponentFactory.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ComponentFactory.java index fc123e0d7..a825c1c7d 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ComponentFactory.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ComponentFactory.java @@ -51,15 +51,13 @@ public class ComponentFactory { } private static List init(Class factoryType, List list) { - list.addAll( - SpringFactoriesLoader.loadFactories( - factoryType, Thread.currentThread().getContextClassLoader())); + list.addAll(SpringFactoriesLoader.loadFactories(factoryType, + Thread.currentThread().getContextClassLoader())); return list; } private static T init(Class factoryType) { - return SpringFactoriesLoader.loadFactories( - factoryType, Thread.currentThread().getContextClassLoader()) - .get(0); + return SpringFactoriesLoader + .loadFactories(factoryType, Thread.currentThread().getContextClassLoader()).get(0); } } diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ResultFormatter.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ResultFormatter.java index 168a9084a..e15c414df 100644 --- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ResultFormatter.java +++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/util/ResultFormatter.java @@ -8,8 +8,8 @@ import java.util.Map; public class ResultFormatter { - public static String transform2TextNew( - List queryColumns, List> queryResults) { + public static String transform2TextNew(List queryColumns, + List> queryResults) { if (CollectionUtils.isEmpty(queryColumns)) { return ""; } diff --git a/common/src/main/java/com/hankcs/hanlp/LoadRemoveService.java b/common/src/main/java/com/hankcs/hanlp/LoadRemoveService.java index 2a47bff7e..f9887a2df 100644 --- a/common/src/main/java/com/hankcs/hanlp/LoadRemoveService.java +++ b/common/src/main/java/com/hankcs/hanlp/LoadRemoveService.java @@ -24,13 +24,12 @@ public class LoadRemoveService { } List resultList = new ArrayList<>(value); if (StringUtils.isNotBlank(mapperRemoveNaturePrefix)) { - resultList.removeIf( - nature -> { - if (Objects.isNull(nature)) { - return false; - } - return nature.startsWith(mapperRemoveNaturePrefix); - }); + resultList.removeIf(nature -> { + if (Objects.isNull(nature)) { + return false; + } + return nature.startsWith(mapperRemoveNaturePrefix); + }); } return resultList; } diff --git a/common/src/main/java/com/hankcs/hanlp/collection/trie/bintrie/BaseNode.java b/common/src/main/java/com/hankcs/hanlp/collection/trie/bintrie/BaseNode.java index 7ee3ab3e0..c8220e1b6 100644 --- a/common/src/main/java/com/hankcs/hanlp/collection/trie/bintrie/BaseNode.java +++ b/common/src/main/java/com/hankcs/hanlp/collection/trie/bintrie/BaseNode.java @@ -253,19 +253,8 @@ public abstract class BaseNode implements Comparable { @Override public String toString() { - return "BaseNode{" - + "child=" - + Arrays.toString(child) - + ", status=" - + status - + ", c=" - + c - + ", value=" - + value - + ", prefix='" - + prefix - + '\'' - + '}'; + return "BaseNode{" + "child=" + Arrays.toString(child) + ", status=" + status + ", c=" + c + + ", value=" + value + ", prefix='" + prefix + '\'' + '}'; } public void walkNode(Set> entrySet) { diff --git a/common/src/main/java/com/hankcs/hanlp/dictionary/CoreDictionary.java b/common/src/main/java/com/hankcs/hanlp/dictionary/CoreDictionary.java index 40d0f3bab..3522c29a1 100644 --- a/common/src/main/java/com/hankcs/hanlp/dictionary/CoreDictionary.java +++ b/common/src/main/java/com/hankcs/hanlp/dictionary/CoreDictionary.java @@ -34,13 +34,8 @@ public class CoreDictionary { if (!load(PATH)) { throw new IllegalArgumentException("核心词典" + PATH + "加载失败"); } else { - Predefine.logger.info( - PATH - + "加载成功," - + trie.size() - + "个词条,耗时" - + (System.currentTimeMillis() - start) - + "ms"); + Predefine.logger.info(PATH + "加载成功," + trie.size() + "个词条,耗时" + + (System.currentTimeMillis() - start) + "ms"); } } @@ -77,22 +72,14 @@ public class CoreDictionary { map.put(param[0], attribute); totalFrequency += attribute.totalFrequency; } - Predefine.logger.info( - "核心词典读入词条" - + map.size() - + " 全部频次" - + totalFrequency - + ",耗时" - + (System.currentTimeMillis() - start) - + "ms"); + Predefine.logger.info("核心词典读入词条" + map.size() + " 全部频次" + totalFrequency + ",耗时" + + (System.currentTimeMillis() - start) + "ms"); br.close(); trie.build(map); Predefine.logger.info("核心词典加载成功:" + trie.size() + "个词条,下面将写入缓存……"); try { - DataOutputStream out = - new DataOutputStream( - new BufferedOutputStream( - IOUtil.newOutputStream(path + Predefine.BIN_EXT))); + DataOutputStream out = new DataOutputStream( + new BufferedOutputStream(IOUtil.newOutputStream(path + Predefine.BIN_EXT))); Collection attributeList = map.values(); out.writeInt(attributeList.size()); for (Attribute attribute : attributeList) { @@ -278,11 +265,8 @@ public class CoreDictionary { } return attribute; } catch (Exception e) { - Predefine.logger.warning( - "使用字符串" - + natureWithFrequency - + "创建词条属性失败!" - + TextUtility.exceptionToString(e)); + Predefine.logger.warning("使用字符串" + natureWithFrequency + "创建词条属性失败!" + + TextUtility.exceptionToString(e)); return null; } } @@ -409,9 +393,7 @@ public class CoreDictionary { if (originals == null || originals.length == 0) { return null; } - return Arrays.stream(originals) - .filter(o -> o != null) - .distinct() + return Arrays.stream(originals).filter(o -> o != null).distinct() .collect(Collectors.toList()); } } diff --git a/common/src/main/java/com/hankcs/hanlp/seg/WordBasedSegment.java b/common/src/main/java/com/hankcs/hanlp/seg/WordBasedSegment.java index a5c16a92e..ff636b623 100644 --- a/common/src/main/java/com/hankcs/hanlp/seg/WordBasedSegment.java +++ b/common/src/main/java/com/hankcs/hanlp/seg/WordBasedSegment.java @@ -47,8 +47,7 @@ public abstract class WordBasedSegment extends Segment { } vertex = (Vertex) var1.next(); - } while (!vertex.realWord.equals("--") - && !vertex.realWord.equals("—") + } while (!vertex.realWord.equals("--") && !vertex.realWord.equals("—") && !vertex.realWord.equals("-")); vertex.confirmNature(Nature.w); @@ -66,8 +65,7 @@ public abstract class WordBasedSegment extends Segment { if (currentNature == Nature.nx && (next.hasNature(Nature.q) || next.hasNature(Nature.n))) { String[] param = current.realWord.split("-", 1); - if (param.length == 2 - && TextUtility.isAllNum(param[0]) + if (param.length == 2 && TextUtility.isAllNum(param[0]) && TextUtility.isAllNum(param[1])) { current = current.copy(); current.realWord = param[0]; @@ -112,10 +110,8 @@ public abstract class WordBasedSegment extends Segment { current.confirmNature(Nature.m, true); } else if (current.realWord.length() > 1) { char last = current.realWord.charAt(current.realWord.length() - 1); - current = - Vertex.newNumberInstance( - current.realWord.substring( - 0, current.realWord.length() - 1)); + current = Vertex.newNumberInstance( + current.realWord.substring(0, current.realWord.length() - 1)); listIterator.previous(); listIterator.previous(); listIterator.set(current); @@ -162,9 +158,7 @@ public abstract class WordBasedSegment extends Segment { charTypeArray[i] = CharType.get(c); if (c == '.' && i < charArray.length - 1 && CharType.get(charArray[i + 1]) == 9) { charTypeArray[i] = 9; - } else if (c == '.' - && i < charArray.length - 1 - && charArray[i + 1] >= '0' + } else if (c == '.' && i < charArray.length - 1 && charArray[i + 1] >= '0' && charArray[i + 1] <= '9') { charTypeArray[i] = 5; } else if (charTypeArray[i] == 8) { @@ -227,7 +221,7 @@ public abstract class WordBasedSegment extends Segment { while (listIterator.hasNext()) { next = (Vertex) listIterator.next(); if (!TextUtility.isAllNum(current.realWord) - && !TextUtility.isAllChineseNum(current.realWord) + && !TextUtility.isAllChineseNum(current.realWord) || !TextUtility.isAllNum(next.realWord) && !TextUtility.isAllChineseNum(next.realWord)) { current = next; @@ -252,21 +246,16 @@ public abstract class WordBasedSegment extends Segment { DoubleArrayTrie.Searcher searcher = CoreDictionary.trie.getSearcher(charArray, 0); while (searcher.next()) { - wordNetStorage.add( - searcher.begin + 1, - new Vertex( - new String(charArray, searcher.begin, searcher.length), - (CoreDictionary.Attribute) searcher.value, - searcher.index)); + wordNetStorage.add(searcher.begin + 1, + new Vertex(new String(charArray, searcher.begin, searcher.length), + (CoreDictionary.Attribute) searcher.value, searcher.index)); } if (this.config.forceCustomDictionary) { - this.customDictionary.parseText( - charArray, + this.customDictionary.parseText(charArray, new AhoCorasickDoubleArrayTrie.IHit() { public void hit(int begin, int end, CoreDictionary.Attribute value) { - wordNetStorage.add( - begin + 1, + wordNetStorage.add(begin + 1, new Vertex(new String(charArray, begin, end - begin), value)); } }); @@ -279,11 +268,9 @@ public abstract class WordBasedSegment extends Segment { while (i < vertexes.length) { if (vertexes[i].isEmpty()) { int j; - for (j = i + 1; - j < vertexes.length - 1 - && (vertexes[j].isEmpty() - || CharType.get(charArray[j - 1]) == 11); - ++j) {} + for (j = i + 1; j < vertexes.length - 1 && (vertexes[j].isEmpty() + || CharType.get(charArray[j - 1]) == 11); ++j) { + } wordNetStorage.add(i, Segment.quickAtomSegment(charArray, i - 1, j - 1)); i = j; @@ -310,10 +297,8 @@ public abstract class WordBasedSegment extends Segment { addTerms(termList, vertex, line - 1); termMain.offset = line - 1; if (vertex.realWord.length() > 2) { - label43: - for (int currentLine = line; - currentLine < line + vertex.realWord.length(); - ++currentLine) { + label43: for (int currentLine = line; currentLine < line + + vertex.realWord.length(); ++currentLine) { Iterator iterator = wordNetAll.descendingIterator(currentLine); while (true) { @@ -327,8 +312,8 @@ public abstract class WordBasedSegment extends Segment { && smallVertex.realWord.length() < this.config.indexMode); if (smallVertex != vertex - && currentLine + smallVertex.realWord.length() - <= line + vertex.realWord.length()) { + && currentLine + smallVertex.realWord.length() <= line + + vertex.realWord.length()) { listIterator.add(smallVertex); // Term termSub = convert(smallVertex); // termSub.offset = currentLine - 1; @@ -346,8 +331,8 @@ public abstract class WordBasedSegment extends Segment { } protected static void speechTagging(List vertexList) { - Viterbi.compute( - vertexList, CoreDictionaryTransformMatrixDictionary.transformMatrixDictionary); + Viterbi.compute(vertexList, + CoreDictionaryTransformMatrixDictionary.transformMatrixDictionary); } protected void addTerms(List terms, Vertex vertex, int offset) { diff --git a/common/src/main/java/com/hankcs/hanlp/seg/common/Term.java b/common/src/main/java/com/hankcs/hanlp/seg/common/Term.java index f98acc642..8bfd2f5e0 100644 --- a/common/src/main/java/com/hankcs/hanlp/seg/common/Term.java +++ b/common/src/main/java/com/hankcs/hanlp/seg/common/Term.java @@ -42,19 +42,13 @@ public class Term { } // todo opt /* - String wordOri = word.toLowerCase(); - CoreDictionary.Attribute attribute = getDynamicCustomDictionary().get(wordOri); - if (attribute == null) { - attribute = CoreDictionary.get(wordOri); - if (attribute == null) { - attribute = CustomDictionary.get(wordOri); - } - } - if (attribute != null && nature != null && attribute.hasNature(nature)) { - return attribute.getNatureFrequency(nature); - } - return attribute == null ? 0 : attribute.totalFrequency; - */ + * String wordOri = word.toLowerCase(); CoreDictionary.Attribute attribute = + * getDynamicCustomDictionary().get(wordOri); if (attribute == null) { attribute = + * CoreDictionary.get(wordOri); if (attribute == null) { attribute = + * CustomDictionary.get(wordOri); } } if (attribute != null && nature != null && + * attribute.hasNature(nature)) { return attribute.getNatureFrequency(nature); } return + * attribute == null ? 0 : attribute.totalFrequency; + */ return 0; } diff --git a/common/src/main/java/com/tencent/supersonic/common/calcite/Configuration.java b/common/src/main/java/com/tencent/supersonic/common/calcite/Configuration.java index b055efabf..749db0d8f 100644 --- a/common/src/main/java/com/tencent/supersonic/common/calcite/Configuration.java +++ b/common/src/main/java/com/tencent/supersonic/common/calcite/Configuration.java @@ -51,19 +51,18 @@ public class Configuration { public static SqlValidator.Config getValidatorConfig(EngineType engineType) { SemanticSqlDialect sqlDialect = SqlDialectFactory.getSqlDialect(engineType); - return SqlValidator.Config.DEFAULT - .withConformance(sqlDialect.getConformance()) + return SqlValidator.Config.DEFAULT.withConformance(sqlDialect.getConformance()) .withDefaultNullCollation(config.defaultNullCollation()) .withLenientOperatorLookup(true); } static { - configProperties.put( - CalciteConnectionProperty.CASE_SENSITIVE.camelName(), Boolean.TRUE.toString()); - configProperties.put( - CalciteConnectionProperty.UNQUOTED_CASING.camelName(), Casing.UNCHANGED.toString()); - configProperties.put( - CalciteConnectionProperty.QUOTED_CASING.camelName(), Casing.TO_LOWER.toString()); + configProperties.put(CalciteConnectionProperty.CASE_SENSITIVE.camelName(), + Boolean.TRUE.toString()); + configProperties.put(CalciteConnectionProperty.UNQUOTED_CASING.camelName(), + Casing.UNCHANGED.toString()); + configProperties.put(CalciteConnectionProperty.QUOTED_CASING.camelName(), + Casing.TO_LOWER.toString()); } public static SqlParser.Config getParserConfig(EngineType engineType) { @@ -76,15 +75,10 @@ public class Configuration { parserConfig.setQuotedCasing(config.quotedCasing()); parserConfig.setConformance(config.conformance()); parserConfig.setLex(Lex.BIG_QUERY); - parserConfig - .setParserFactory(SqlParserImpl.FACTORY) - .setCaseSensitive(false) - .setIdentifierMaxLength(Integer.MAX_VALUE) - .setQuoting(Quoting.BACK_TICK) - .setQuoting(Quoting.SINGLE_QUOTE) - .setQuotedCasing(Casing.TO_UPPER) - .setUnquotedCasing(Casing.TO_UPPER) - .setConformance(sqlDialect.getConformance()) + parserConfig.setParserFactory(SqlParserImpl.FACTORY).setCaseSensitive(false) + .setIdentifierMaxLength(Integer.MAX_VALUE).setQuoting(Quoting.BACK_TICK) + .setQuoting(Quoting.SINGLE_QUOTE).setQuotedCasing(Casing.TO_UPPER) + .setUnquotedCasing(Casing.TO_UPPER).setConformance(sqlDialect.getConformance()) .setLex(Lex.BIG_QUERY); parserConfig = parserConfig.setQuotedCasing(Casing.UNCHANGED); parserConfig = parserConfig.setUnquotedCasing(Casing.UNCHANGED); @@ -96,61 +90,39 @@ public class Configuration { tables.add(SqlStdOperatorTable.instance()); SqlOperatorTable operatorTable = new ChainedSqlOperatorTable(tables); // operatorTable. - Prepare.CatalogReader catalogReader = - new CalciteCatalogReader( - rootSchema, - Collections.singletonList(rootSchema.getName()), - typeFactory, - config); - return SqlValidatorUtil.newValidator( - operatorTable, - catalogReader, - typeFactory, + Prepare.CatalogReader catalogReader = new CalciteCatalogReader(rootSchema, + Collections.singletonList(rootSchema.getName()), typeFactory, config); + return SqlValidatorUtil.newValidator(operatorTable, catalogReader, typeFactory, Configuration.getValidatorConfig(engineType)); } - public static SqlValidatorWithHints getSqlValidatorWithHints( - CalciteSchema rootSchema, EngineType engineTyp) { - return new SqlAdvisorValidator( - SqlStdOperatorTable.instance(), - new CalciteCatalogReader( - rootSchema, - Collections.singletonList(rootSchema.getName()), - typeFactory, - config), - typeFactory, - SqlValidator.Config.DEFAULT); + public static SqlValidatorWithHints getSqlValidatorWithHints(CalciteSchema rootSchema, + EngineType engineTyp) { + return new SqlAdvisorValidator(SqlStdOperatorTable.instance(), + new CalciteCatalogReader(rootSchema, + Collections.singletonList(rootSchema.getName()), typeFactory, config), + typeFactory, SqlValidator.Config.DEFAULT); } public static SqlToRelConverter.Config getConverterConfig() { HintStrategyTable strategies = HintStrategyTable.builder().build(); - return SqlToRelConverter.config() - .withHintStrategyTable(strategies) - .withTrimUnusedFields(true) - .withExpand(true) + return SqlToRelConverter.config().withHintStrategyTable(strategies) + .withTrimUnusedFields(true).withExpand(true) .addRelBuilderConfigTransform(c -> c.withSimplify(false)); } - public static SqlToRelConverter getSqlToRelConverter( - SqlValidatorScope scope, - SqlValidator sqlValidator, - RelOptPlanner relOptPlanner, - EngineType engineType) { + public static SqlToRelConverter getSqlToRelConverter(SqlValidatorScope scope, + SqlValidator sqlValidator, RelOptPlanner relOptPlanner, EngineType engineType) { RexBuilder rexBuilder = new RexBuilder(typeFactory); RelOptCluster cluster = RelOptCluster.create(relOptPlanner, rexBuilder); FrameworkConfig fromworkConfig = - Frameworks.newConfigBuilder() - .parserConfig(getParserConfig(engineType)) + Frameworks.newConfigBuilder().parserConfig(getParserConfig(engineType)) .defaultSchema( scope.getValidator().getCatalogReader().getRootSchema().plus()) .build(); - return new SqlToRelConverter( - new ViewExpanderImpl(), - sqlValidator, - (CatalogReader) scope.getValidator().getCatalogReader(), - cluster, - fromworkConfig.getConvertletTable(), - getConverterConfig()); + return new SqlToRelConverter(new ViewExpanderImpl(), sqlValidator, + (CatalogReader) scope.getValidator().getCatalogReader(), cluster, + fromworkConfig.getConvertletTable(), getConverterConfig()); } public static SqlAdvisor getSqlAdvisor(SqlValidatorWithHints validator, EngineType engineType) { @@ -159,15 +131,10 @@ public class Configuration { public static SqlWriterConfig getSqlWriterConfig(EngineType engineType) { SemanticSqlDialect sqlDialect = SqlDialectFactory.getSqlDialect(engineType); - SqlWriterConfig config = - SqlPrettyWriter.config() - .withDialect(sqlDialect) - .withKeywordsLowerCase(false) - .withClauseEndsLine(true) - .withAlwaysUseParentheses(false) - .withSelectListItemsOnSeparateLines(false) - .withUpdateSetListNewline(false) - .withIndentation(0); + SqlWriterConfig config = SqlPrettyWriter.config().withDialect(sqlDialect) + .withKeywordsLowerCase(false).withClauseEndsLine(true) + .withAlwaysUseParentheses(false).withSelectListItemsOnSeparateLines(false) + .withUpdateSetListNewline(false).withIndentation(0); if (EngineType.MYSQL.equals(engineType)) { // no backticks around function name config = config.withQuoteAllIdentifiers(false); diff --git a/common/src/main/java/com/tencent/supersonic/common/calcite/SemanticSqlDialect.java b/common/src/main/java/com/tencent/supersonic/common/calcite/SemanticSqlDialect.java index 42938a6bf..153aa5e46 100644 --- a/common/src/main/java/com/tencent/supersonic/common/calcite/SemanticSqlDialect.java +++ b/common/src/main/java/com/tencent/supersonic/common/calcite/SemanticSqlDialect.java @@ -17,8 +17,8 @@ public class SemanticSqlDialect extends SqlDialect { super(context); } - public static void unparseFetchUsingAnsi( - SqlWriter writer, @Nullable SqlNode offset, @Nullable SqlNode fetch) { + public static void unparseFetchUsingAnsi(SqlWriter writer, @Nullable SqlNode offset, + @Nullable SqlNode fetch) { Preconditions.checkArgument(fetch != null || offset != null); SqlWriter.Frame fetchFrame; writer.newlineAndIndent(); @@ -74,11 +74,11 @@ public class SemanticSqlDialect extends SqlDialect { return true; } - public void unparseSqlIntervalLiteral( - SqlWriter writer, SqlIntervalLiteral literal, int leftPrec, int rightPrec) {} + public void unparseSqlIntervalLiteral(SqlWriter writer, SqlIntervalLiteral literal, + int leftPrec, int rightPrec) {} - public void unparseOffsetFetch( - SqlWriter writer, @Nullable SqlNode offset, @Nullable SqlNode fetch) { + public void unparseOffsetFetch(SqlWriter writer, @Nullable SqlNode offset, + @Nullable SqlNode fetch) { unparseFetchUsingAnsi(writer, offset, fetch); } } diff --git a/common/src/main/java/com/tencent/supersonic/common/calcite/SqlDialectFactory.java b/common/src/main/java/com/tencent/supersonic/common/calcite/SqlDialectFactory.java index 14cd22796..1abdb1fd1 100644 --- a/common/src/main/java/com/tencent/supersonic/common/calcite/SqlDialectFactory.java +++ b/common/src/main/java/com/tencent/supersonic/common/calcite/SqlDialectFactory.java @@ -13,22 +13,14 @@ import java.util.Objects; public class SqlDialectFactory { public static final Context DEFAULT_CONTEXT = - SqlDialect.EMPTY_CONTEXT - .withDatabaseProduct(DatabaseProduct.BIG_QUERY) - .withLiteralQuoteString("'") - .withLiteralEscapedQuoteString("''") - .withIdentifierQuoteString("`") - .withUnquotedCasing(Casing.UNCHANGED) - .withQuotedCasing(Casing.UNCHANGED) - .withCaseSensitive(false); - public static final Context POSTGRESQL_CONTEXT = - SqlDialect.EMPTY_CONTEXT - .withDatabaseProduct(DatabaseProduct.BIG_QUERY) - .withLiteralQuoteString("'") - .withLiteralEscapedQuoteString("''") - .withUnquotedCasing(Casing.UNCHANGED) - .withQuotedCasing(Casing.UNCHANGED) - .withCaseSensitive(false); + SqlDialect.EMPTY_CONTEXT.withDatabaseProduct(DatabaseProduct.BIG_QUERY) + .withLiteralQuoteString("'").withLiteralEscapedQuoteString("''") + .withIdentifierQuoteString("`").withUnquotedCasing(Casing.UNCHANGED) + .withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(false); + public static final Context POSTGRESQL_CONTEXT = SqlDialect.EMPTY_CONTEXT + .withDatabaseProduct(DatabaseProduct.BIG_QUERY).withLiteralQuoteString("'") + .withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED) + .withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(false); private static Map sqlDialectMap; static { diff --git a/common/src/main/java/com/tencent/supersonic/common/calcite/SqlMergeWithUtils.java b/common/src/main/java/com/tencent/supersonic/common/calcite/SqlMergeWithUtils.java index 92a8ec15d..a0c375698 100644 --- a/common/src/main/java/com/tencent/supersonic/common/calcite/SqlMergeWithUtils.java +++ b/common/src/main/java/com/tencent/supersonic/common/calcite/SqlMergeWithUtils.java @@ -20,12 +20,8 @@ import java.util.List; @Slf4j public class SqlMergeWithUtils { - public static String mergeWith( - EngineType engineType, - String sql, - List parentSqlList, - List parentWithNameList) - throws SqlParseException { + public static String mergeWith(EngineType engineType, String sql, List parentSqlList, + List parentWithNameList) throws SqlParseException { SqlParser.Config parserConfig = Configuration.getParserConfig(engineType); // Parse the main SQL statement @@ -45,14 +41,12 @@ public class SqlMergeWithUtils { SqlNode sqlNode2 = parser.parseQuery(); // Create a new WITH item for parentWithName without quotes - SqlWithItem withItem = - new SqlWithItem( - SqlParserPos.ZERO, - new SqlIdentifier( - parentWithName, SqlParserPos.ZERO), // false to avoid quotes - null, - sqlNode2, - SqlLiteral.createBoolean(false, SqlParserPos.ZERO)); + SqlWithItem withItem = new SqlWithItem(SqlParserPos.ZERO, + new SqlIdentifier(parentWithName, SqlParserPos.ZERO), // false + // to + // avoid + // quotes + null, sqlNode2, SqlLiteral.createBoolean(false, SqlParserPos.ZERO)); // Add the new WITH item to the list withItemList.add(withItem); @@ -66,11 +60,8 @@ public class SqlMergeWithUtils { } // Create a new SqlWith node - SqlWith finalSqlNode = - new SqlWith( - SqlParserPos.ZERO, - new SqlNodeList(withItemList, SqlParserPos.ZERO), - sqlNode1); + SqlWith finalSqlNode = new SqlWith(SqlParserPos.ZERO, + new SqlNodeList(withItemList, SqlParserPos.ZERO), sqlNode1); // Custom SqlPrettyWriter configuration to avoid quoting identifiers SqlWriterConfig config = Configuration.getSqlWriterConfig(engineType); // Pretty print the final SQL diff --git a/common/src/main/java/com/tencent/supersonic/common/calcite/SqlParseUtils.java b/common/src/main/java/com/tencent/supersonic/common/calcite/SqlParseUtils.java index 42b040d87..26a71c4c5 100644 --- a/common/src/main/java/com/tencent/supersonic/common/calcite/SqlParseUtils.java +++ b/common/src/main/java/com/tencent/supersonic/common/calcite/SqlParseUtils.java @@ -45,10 +45,8 @@ public class SqlParseUtils { sqlParserInfo.setAllFields( sqlParserInfo.getAllFields().stream().distinct().collect(Collectors.toList())); - sqlParserInfo.setSelectFields( - sqlParserInfo.getSelectFields().stream() - .distinct() - .collect(Collectors.toList())); + sqlParserInfo.setSelectFields(sqlParserInfo.getSelectFields().stream().distinct() + .collect(Collectors.toList())); return sqlParserInfo; } catch (SqlParseException e) { @@ -108,13 +106,10 @@ public class SqlParseUtils { SqlSelect sqlSelect = (SqlSelect) select; SqlNodeList selectList = sqlSelect.getSelectList(); - selectList - .getList() - .forEach( - list -> { - Set selectFields = handlerField(list); - sqlParserInfo.getSelectFields().addAll(selectFields); - }); + selectList.getList().forEach(list -> { + Set selectFields = handlerField(list); + sqlParserInfo.getSelectFields().addAll(selectFields); + }); String tableName = handlerFrom(sqlSelect.getFrom()); sqlParserInfo.setTableName(tableName); @@ -129,14 +124,10 @@ public class SqlParseUtils { results.addAll(formFields); } - sqlSelect - .getSelectList() - .getList() - .forEach( - list -> { - Set selectFields = handlerField(list); - results.addAll(selectFields); - }); + sqlSelect.getSelectList().getList().forEach(list -> { + Set selectFields = handlerField(list); + results.addAll(selectFields); + }); if (sqlSelect.hasWhere()) { Set whereFields = handlerField(sqlSelect.getWhere()); @@ -148,11 +139,10 @@ public class SqlParseUtils { } SqlNodeList group = sqlSelect.getGroup(); if (group != null) { - group.forEach( - groupField -> { - Set groupByFields = handlerField(groupField); - results.addAll(groupByFields); - }); + group.forEach(groupField -> { + Set groupByFields = handlerField(groupField); + results.addAll(groupByFields); + }); } return results; } @@ -213,12 +203,9 @@ public class SqlParseUtils { } } if (field instanceof SqlNodeList) { - ((SqlNodeList) field) - .getList() - .forEach( - node -> { - fields.addAll(handlerField(node)); - }); + ((SqlNodeList) field).getList().forEach(node -> { + fields.addAll(handlerField(node)); + }); } break; } @@ -243,12 +230,9 @@ public class SqlParseUtils { SqlIdentifier sqlIdentifier = (SqlIdentifier) operandList.get(0); String simple = sqlIdentifier.getSimple(); SqlBasicCall aliasedNode = - new SqlBasicCall( - SqlStdOperatorTable.AS, - new SqlNode[] { - sqlBasicCall, - new SqlIdentifier(simple.toLowerCase(), SqlParserPos.ZERO) - }, + new SqlBasicCall(SqlStdOperatorTable.AS, + new SqlNode[] {sqlBasicCall, new SqlIdentifier( + simple.toLowerCase(), SqlParserPos.ZERO)}, SqlParserPos.ZERO); selectList.set(selectList.indexOf(node), aliasedNode); } diff --git a/common/src/main/java/com/tencent/supersonic/common/calcite/ViewExpanderImpl.java b/common/src/main/java/com/tencent/supersonic/common/calcite/ViewExpanderImpl.java index 419627f10..9fc6366c9 100644 --- a/common/src/main/java/com/tencent/supersonic/common/calcite/ViewExpanderImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/calcite/ViewExpanderImpl.java @@ -11,10 +11,7 @@ public class ViewExpanderImpl implements RelOptTable.ViewExpander { public ViewExpanderImpl() {} @Override - public RelRoot expandView( - RelDataType rowType, - String queryString, - List schemaPath, + public RelRoot expandView(RelDataType rowType, String queryString, List schemaPath, List dataSetPath) { return null; } diff --git a/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java index 97efa3b4f..c61ac0596 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/ChatModelParameterConfig.java @@ -20,98 +20,37 @@ import java.util.List; @Slf4j public class ChatModelParameterConfig extends ParameterConfig { - public static final Parameter CHAT_MODEL_PROVIDER = - new Parameter( - "s2.chat.model.provider", - OpenAiModelFactory.PROVIDER, - "接口协议", - "", - "list", - "对话模型配置", - getCandidateValues()); + public static final Parameter CHAT_MODEL_PROVIDER = new Parameter("s2.chat.model.provider", + OpenAiModelFactory.PROVIDER, "接口协议", "", "list", "对话模型配置", getCandidateValues()); public static final Parameter CHAT_MODEL_BASE_URL = - new Parameter( - "s2.chat.model.base.url", - OpenAiModelFactory.DEFAULT_BASE_URL, - "BaseUrl", - "", - "string", - "对话模型配置", - null, - getBaseUrlDependency()); - public static final Parameter CHAT_MODEL_ENDPOINT = - new Parameter( - "s2.chat.model.endpoint", - "llama_2_70b", - "Endpoint", - "", - "string", - "对话模型配置", - null, - getEndpointDependency()); - public static final Parameter CHAT_MODEL_API_KEY = - new Parameter( - "s2.chat.model.api.key", - DEMO, - "ApiKey", - "", - "password", - "对话模型配置", - null, - getApiKeyDependency()); - public static final Parameter CHAT_MODEL_SECRET_KEY = - new Parameter( - "s2.chat.model.secretKey", - "demo", - "SecretKey", - "", - "password", - "对话模型配置", - null, - getSecretKeyDependency()); + new Parameter("s2.chat.model.base.url", OpenAiModelFactory.DEFAULT_BASE_URL, "BaseUrl", + "", "string", "对话模型配置", null, getBaseUrlDependency()); + public static final Parameter CHAT_MODEL_ENDPOINT = new Parameter("s2.chat.model.endpoint", + "llama_2_70b", "Endpoint", "", "string", "对话模型配置", null, getEndpointDependency()); + public static final Parameter CHAT_MODEL_API_KEY = new Parameter("s2.chat.model.api.key", DEMO, + "ApiKey", "", "password", "对话模型配置", null, getApiKeyDependency()); + public static final Parameter CHAT_MODEL_SECRET_KEY = new Parameter("s2.chat.model.secretKey", + "demo", "SecretKey", "", "password", "对话模型配置", null, getSecretKeyDependency()); - public static final Parameter CHAT_MODEL_NAME = - new Parameter( - "s2.chat.model.name", - "gpt-4o-mini", - "ModelName", - "", - "string", - "对话模型配置", - null, - getModelNameDependency()); + public static final Parameter CHAT_MODEL_NAME = new Parameter("s2.chat.model.name", + "gpt-4o-mini", "ModelName", "", "string", "对话模型配置", null, getModelNameDependency()); public static final Parameter CHAT_MODEL_ENABLE_SEARCH = - new Parameter( - "s2.chat.model.enableSearch", - "false", - "是否启用搜索增强功能,设为false表示不启用", - "", - "bool", - "对话模型配置", - null, - getEnableSearchDependency()); + new Parameter("s2.chat.model.enableSearch", "false", "是否启用搜索增强功能,设为false表示不启用", "", + "bool", "对话模型配置", null, getEnableSearchDependency()); - public static final Parameter CHAT_MODEL_TEMPERATURE = - new Parameter( - "s2.chat.model.temperature", "0.0", "Temperature", "", "slider", "对话模型配置"); + public static final Parameter CHAT_MODEL_TEMPERATURE = new Parameter( + "s2.chat.model.temperature", "0.0", "Temperature", "", "slider", "对话模型配置"); public static final Parameter CHAT_MODEL_TIMEOUT = new Parameter("s2.chat.model.timeout", "60", "超时时间(秒)", "", "number", "对话模型配置"); @Override public List getSysParameters() { - return Lists.newArrayList( - CHAT_MODEL_PROVIDER, - CHAT_MODEL_BASE_URL, - CHAT_MODEL_ENDPOINT, - CHAT_MODEL_API_KEY, - CHAT_MODEL_SECRET_KEY, - CHAT_MODEL_NAME, - CHAT_MODEL_ENABLE_SEARCH, - CHAT_MODEL_TEMPERATURE, - CHAT_MODEL_TIMEOUT); + return Lists.newArrayList(CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT, + CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME, + CHAT_MODEL_ENABLE_SEARCH, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT); } public ChatModelConfig convert() { @@ -125,36 +64,24 @@ public class ChatModelParameterConfig extends ParameterConfig { String secretKey = getParameterValue(CHAT_MODEL_SECRET_KEY); String enableSearch = getParameterValue(CHAT_MODEL_ENABLE_SEARCH); - return ChatModelConfig.builder() - .provider(chatModelProvider) - .baseUrl(chatModelBaseUrl) - .apiKey(chatModelApiKey) - .modelName(chatModelName) + return ChatModelConfig.builder().provider(chatModelProvider).baseUrl(chatModelBaseUrl) + .apiKey(chatModelApiKey).modelName(chatModelName) .enableSearch(Boolean.valueOf(enableSearch)) .temperature(Double.valueOf(chatModelTemperature)) - .timeOut(Long.valueOf(chatModelTimeout)) - .endpoint(endpoint) - .secretKey(secretKey) + .timeOut(Long.valueOf(chatModelTimeout)).endpoint(endpoint).secretKey(secretKey) .build(); } private static List getCandidateValues() { - return Lists.newArrayList( - OpenAiModelFactory.PROVIDER, - OllamaModelFactory.PROVIDER, - QianfanModelFactory.PROVIDER, - ZhipuModelFactory.PROVIDER, - LocalAiModelFactory.PROVIDER, - DashscopeModelFactory.PROVIDER, + return Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER, + QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER, + LocalAiModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER, AzureModelFactory.PROVIDER); } private static List getBaseUrlDependency() { - return getDependency( - CHAT_MODEL_PROVIDER.getName(), - getCandidateValues(), - ImmutableMap.of( - OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL, + return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(), + ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL, AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL, OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL, QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL, @@ -164,30 +91,18 @@ public class ChatModelParameterConfig extends ParameterConfig { } private static List getApiKeyDependency() { - return getDependency( - CHAT_MODEL_PROVIDER.getName(), - Lists.newArrayList( - OpenAiModelFactory.PROVIDER, - QianfanModelFactory.PROVIDER, - ZhipuModelFactory.PROVIDER, - LocalAiModelFactory.PROVIDER, - AzureModelFactory.PROVIDER, - DashscopeModelFactory.PROVIDER), - ImmutableMap.of( - OpenAiModelFactory.PROVIDER, DEMO, - QianfanModelFactory.PROVIDER, DEMO, - ZhipuModelFactory.PROVIDER, DEMO, - LocalAiModelFactory.PROVIDER, DEMO, - AzureModelFactory.PROVIDER, DEMO, - DashscopeModelFactory.PROVIDER, DEMO)); + return getDependency(CHAT_MODEL_PROVIDER.getName(), + Lists.newArrayList(OpenAiModelFactory.PROVIDER, QianfanModelFactory.PROVIDER, + ZhipuModelFactory.PROVIDER, LocalAiModelFactory.PROVIDER, + AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER), + ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO, QianfanModelFactory.PROVIDER, + DEMO, ZhipuModelFactory.PROVIDER, DEMO, LocalAiModelFactory.PROVIDER, DEMO, + AzureModelFactory.PROVIDER, DEMO, DashscopeModelFactory.PROVIDER, DEMO)); } private static List getModelNameDependency() { - return getDependency( - CHAT_MODEL_PROVIDER.getName(), - getCandidateValues(), - ImmutableMap.of( - OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME, + return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(), + ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME, OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_MODEL_NAME, QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_MODEL_NAME, ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_MODEL_NAME, @@ -197,23 +112,19 @@ public class ChatModelParameterConfig extends ParameterConfig { } private static List getEndpointDependency() { - return getDependency( - CHAT_MODEL_PROVIDER.getName(), - Lists.newArrayList(QianfanModelFactory.PROVIDER), - ImmutableMap.of( - QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_ENDPOINT)); + return getDependency(CHAT_MODEL_PROVIDER.getName(), + Lists.newArrayList(QianfanModelFactory.PROVIDER), ImmutableMap + .of(QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_ENDPOINT)); } private static List getEnableSearchDependency() { - return getDependency( - CHAT_MODEL_PROVIDER.getName(), + return getDependency(CHAT_MODEL_PROVIDER.getName(), Lists.newArrayList(DashscopeModelFactory.PROVIDER), ImmutableMap.of(DashscopeModelFactory.PROVIDER, "false")); } private static List getSecretKeyDependency() { - return getDependency( - CHAT_MODEL_PROVIDER.getName(), + return getDependency(CHAT_MODEL_PROVIDER.getName(), Lists.newArrayList(QianfanModelFactory.PROVIDER), ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO)); } diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java index 7344de23c..e06431eac 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingModelParameterConfig.java @@ -22,89 +22,35 @@ import java.util.List; @Slf4j public class EmbeddingModelParameterConfig extends ParameterConfig { public static final Parameter EMBEDDING_MODEL_PROVIDER = - new Parameter( - "s2.embedding.model.provider", - InMemoryModelFactory.PROVIDER, - "接口协议", - "", - "list", - "向量模型配置", - getCandidateValues()); + new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER, "接口协议", "", + "list", "向量模型配置", getCandidateValues()); public static final Parameter EMBEDDING_MODEL_BASE_URL = - new Parameter( - "s2.embedding.model.base.url", - "", - "BaseUrl", - "", - "string", - "向量模型配置", - null, - getBaseUrlDependency()); + new Parameter("s2.embedding.model.base.url", "", "BaseUrl", "", "string", "向量模型配置", + null, getBaseUrlDependency()); public static final Parameter EMBEDDING_MODEL_API_KEY = - new Parameter( - "s2.embedding.model.api.key", - "", - "ApiKey", - "", - "password", - "向量模型配置", - null, - getApiKeyDependency()); + new Parameter("s2.embedding.model.api.key", "", "ApiKey", "", "password", "向量模型配置", + null, getApiKeyDependency()); public static final Parameter EMBEDDING_MODEL_SECRET_KEY = - new Parameter( - "s2.embedding.model.secretKey", - "demo", - "SecretKey", - "", - "password", - "向量模型配置", - null, - getSecretKeyDependency()); + new Parameter("s2.embedding.model.secretKey", "demo", "SecretKey", "", "password", + "向量模型配置", null, getSecretKeyDependency()); public static final Parameter EMBEDDING_MODEL_NAME = - new Parameter( - "s2.embedding.model.name", - EmbeddingModelConstant.BGE_SMALL_ZH, - "ModelName", - "", - "string", - "向量模型配置", - null, - getModelNameDependency()); + new Parameter("s2.embedding.model.name", EmbeddingModelConstant.BGE_SMALL_ZH, + "ModelName", "", "string", "向量模型配置", null, getModelNameDependency()); - public static final Parameter EMBEDDING_MODEL_PATH = - new Parameter( - "s2.embedding.model.path", - "", - "模型路径", - "", - "string", - "向量模型配置", - null, - getModelPathDependency()); + public static final Parameter EMBEDDING_MODEL_PATH = new Parameter("s2.embedding.model.path", + "", "模型路径", "", "string", "向量模型配置", null, getModelPathDependency()); public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH = - new Parameter( - "s2.embedding.model.vocabulary.path", - "", - "词汇表路径", - "", - "string", - "向量模型配置", - null, - getModelPathDependency()); + new Parameter("s2.embedding.model.vocabulary.path", "", "词汇表路径", "", "string", "向量模型配置", + null, getModelPathDependency()); @Override public List getSysParameters() { - return Lists.newArrayList( - EMBEDDING_MODEL_PROVIDER, - EMBEDDING_MODEL_BASE_URL, - EMBEDDING_MODEL_API_KEY, - EMBEDDING_MODEL_SECRET_KEY, - EMBEDDING_MODEL_NAME, - EMBEDDING_MODEL_PATH, - EMBEDDING_MODEL_VOCABULARY_PATH); + return Lists.newArrayList(EMBEDDING_MODEL_PROVIDER, EMBEDDING_MODEL_BASE_URL, + EMBEDDING_MODEL_API_KEY, EMBEDDING_MODEL_SECRET_KEY, EMBEDDING_MODEL_NAME, + EMBEDDING_MODEL_PATH, EMBEDDING_MODEL_VOCABULARY_PATH); } public EmbeddingModelConfig convert() { @@ -115,40 +61,24 @@ public class EmbeddingModelParameterConfig extends ParameterConfig { String modelPath = getParameterValue(EMBEDDING_MODEL_PATH); String vocabularyPath = getParameterValue(EMBEDDING_MODEL_VOCABULARY_PATH); String secretKey = getParameterValue(EMBEDDING_MODEL_SECRET_KEY); - return EmbeddingModelConfig.builder() - .provider(provider) - .baseUrl(baseUrl) - .apiKey(apiKey) - .secretKey(secretKey) - .modelName(modelName) - .modelPath(modelPath) - .vocabularyPath(vocabularyPath) - .build(); + return EmbeddingModelConfig.builder().provider(provider).baseUrl(baseUrl).apiKey(apiKey) + .secretKey(secretKey).modelName(modelName).modelPath(modelPath) + .vocabularyPath(vocabularyPath).build(); } private static ArrayList getCandidateValues() { - return Lists.newArrayList( - InMemoryModelFactory.PROVIDER, - OpenAiModelFactory.PROVIDER, - OllamaModelFactory.PROVIDER, - DashscopeModelFactory.PROVIDER, - QianfanModelFactory.PROVIDER, - ZhipuModelFactory.PROVIDER, + return Lists.newArrayList(InMemoryModelFactory.PROVIDER, OpenAiModelFactory.PROVIDER, + OllamaModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER, + QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER, AzureModelFactory.PROVIDER); } private static List getBaseUrlDependency() { - return getDependency( - EMBEDDING_MODEL_PROVIDER.getName(), - Lists.newArrayList( - OpenAiModelFactory.PROVIDER, - OllamaModelFactory.PROVIDER, - AzureModelFactory.PROVIDER, - DashscopeModelFactory.PROVIDER, - QianfanModelFactory.PROVIDER, - ZhipuModelFactory.PROVIDER), - ImmutableMap.of( - OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL, + return getDependency(EMBEDDING_MODEL_PROVIDER.getName(), + Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER, + AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER, + QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER), + ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL, OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL, AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL, DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL, @@ -157,63 +87,43 @@ public class EmbeddingModelParameterConfig extends ParameterConfig { } private static List getApiKeyDependency() { - return getDependency( - EMBEDDING_MODEL_PROVIDER.getName(), - Lists.newArrayList( - OpenAiModelFactory.PROVIDER, - AzureModelFactory.PROVIDER, - DashscopeModelFactory.PROVIDER, - QianfanModelFactory.PROVIDER, + return getDependency(EMBEDDING_MODEL_PROVIDER.getName(), + Lists.newArrayList(OpenAiModelFactory.PROVIDER, AzureModelFactory.PROVIDER, + DashscopeModelFactory.PROVIDER, QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER), - ImmutableMap.of( - OpenAiModelFactory.PROVIDER, - DEMO, - AzureModelFactory.PROVIDER, - DEMO, - DashscopeModelFactory.PROVIDER, - DEMO, - QianfanModelFactory.PROVIDER, - DEMO, - ZhipuModelFactory.PROVIDER, - DEMO)); + ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO, AzureModelFactory.PROVIDER, DEMO, + DashscopeModelFactory.PROVIDER, DEMO, QianfanModelFactory.PROVIDER, DEMO, + ZhipuModelFactory.PROVIDER, DEMO)); } private static List getModelNameDependency() { - return getDependency( - EMBEDDING_MODEL_PROVIDER.getName(), - Lists.newArrayList( - InMemoryModelFactory.PROVIDER, - OpenAiModelFactory.PROVIDER, - OllamaModelFactory.PROVIDER, - AzureModelFactory.PROVIDER, - DashscopeModelFactory.PROVIDER, - QianfanModelFactory.PROVIDER, + return getDependency(EMBEDDING_MODEL_PROVIDER.getName(), + Lists.newArrayList(InMemoryModelFactory.PROVIDER, OpenAiModelFactory.PROVIDER, + OllamaModelFactory.PROVIDER, AzureModelFactory.PROVIDER, + DashscopeModelFactory.PROVIDER, QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER), - ImmutableMap.of( - InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH, + ImmutableMap.of(InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH, OpenAiModelFactory.PROVIDER, - OpenAiModelFactory.DEFAULT_EMBEDDING_MODEL_NAME, + OpenAiModelFactory.DEFAULT_EMBEDDING_MODEL_NAME, OllamaModelFactory.PROVIDER, - OllamaModelFactory.DEFAULT_EMBEDDING_MODEL_NAME, - AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_EMBEDDING_MODEL_NAME, + OllamaModelFactory.DEFAULT_EMBEDDING_MODEL_NAME, AzureModelFactory.PROVIDER, + AzureModelFactory.DEFAULT_EMBEDDING_MODEL_NAME, DashscopeModelFactory.PROVIDER, - DashscopeModelFactory.DEFAULT_EMBEDDING_MODEL_NAME, + DashscopeModelFactory.DEFAULT_EMBEDDING_MODEL_NAME, QianfanModelFactory.PROVIDER, - QianfanModelFactory.DEFAULT_EMBEDDING_MODEL_NAME, + QianfanModelFactory.DEFAULT_EMBEDDING_MODEL_NAME, ZhipuModelFactory.PROVIDER, - ZhipuModelFactory.DEFAULT_EMBEDDING_MODEL_NAME)); + ZhipuModelFactory.DEFAULT_EMBEDDING_MODEL_NAME)); } private static List getModelPathDependency() { - return getDependency( - EMBEDDING_MODEL_PROVIDER.getName(), + return getDependency(EMBEDDING_MODEL_PROVIDER.getName(), Lists.newArrayList(InMemoryModelFactory.PROVIDER), ImmutableMap.of(InMemoryModelFactory.PROVIDER, "")); } private static List getSecretKeyDependency() { - return getDependency( - EMBEDDING_MODEL_PROVIDER.getName(), + return getDependency(EMBEDDING_MODEL_PROVIDER.getName(), Lists.newArrayList(QianfanModelFactory.PROVIDER), ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO)); } diff --git a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java index 7284fe189..6ca2bad3d 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/EmbeddingStoreParameterConfig.java @@ -15,83 +15,38 @@ import java.util.List; @Service("EmbeddingStoreParameterConfig") @Slf4j public class EmbeddingStoreParameterConfig extends ParameterConfig { - public static final Parameter EMBEDDING_STORE_PROVIDER = - new Parameter( - "s2.embedding.store.provider", - EmbeddingStoreType.IN_MEMORY.name(), - "向量库类型", - "目前支持三种类型:IN_MEMORY、MILVUS、CHROMA", - "list", - "向量库配置", - getCandidateValues()); + public static final Parameter EMBEDDING_STORE_PROVIDER = new Parameter( + "s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(), "向量库类型", + "目前支持三种类型:IN_MEMORY、MILVUS、CHROMA", "list", "向量库配置", getCandidateValues()); public static final Parameter EMBEDDING_STORE_BASE_URL = - new Parameter( - "s2.embedding.store.base.url", - "", - "BaseUrl", - "", - "string", - "向量库配置", - null, + new Parameter("s2.embedding.store.base.url", "", "BaseUrl", "", "string", "向量库配置", null, getBaseUrlDependency()); public static final Parameter EMBEDDING_STORE_API_KEY = - new Parameter( - "s2.embedding.store.api.key", - "", - "ApiKey", - "", - "password", - "向量库配置", - null, + new Parameter("s2.embedding.store.api.key", "", "ApiKey", "", "password", "向量库配置", null, getApiKeyDependency()); public static final Parameter EMBEDDING_STORE_PERSIST_PATH = - new Parameter( - "s2.embedding.store.persist.path", - "", - "持久化路径", - "默认不持久化,如需持久化请填写持久化路径。" + "注意:如果变更了向量模型需删除该路径下已保存的文件或修改持久化路径", - "string", - "向量库配置", - null, - getPathDependency()); + new Parameter("s2.embedding.store.persist.path", "", "持久化路径", + "默认不持久化,如需持久化请填写持久化路径。" + "注意:如果变更了向量模型需删除该路径下已保存的文件或修改持久化路径", "string", + "向量库配置", null, getPathDependency()); public static final Parameter EMBEDDING_STORE_TIMEOUT = new Parameter("s2.embedding.store.timeout", "60", "超时时间(秒)", "", "number", "向量库配置"); public static final Parameter EMBEDDING_STORE_DIMENSION = - new Parameter( - "s2.embedding.store.dimension", - "", - "纬度", - "", - "number", - "向量库配置", - null, + new Parameter("s2.embedding.store.dimension", "", "纬度", "", "number", "向量库配置", null, getDimensionDependency()); public static final Parameter EMBEDDING_STORE_DATABASE_NAME = - new Parameter( - "s2.embedding.store.databaseName", - "", - "DatabaseName", - "", - "string", - "向量库配置", - null, - getDatabaseNameDependency()); + new Parameter("s2.embedding.store.databaseName", "", "DatabaseName", "", "string", + "向量库配置", null, getDatabaseNameDependency()); @Override public List getSysParameters() { - return Lists.newArrayList( - EMBEDDING_STORE_PROVIDER, - EMBEDDING_STORE_BASE_URL, - EMBEDDING_STORE_API_KEY, - EMBEDDING_STORE_DATABASE_NAME, - EMBEDDING_STORE_PERSIST_PATH, - EMBEDDING_STORE_TIMEOUT, - EMBEDDING_STORE_DIMENSION); + return Lists.newArrayList(EMBEDDING_STORE_PROVIDER, EMBEDDING_STORE_BASE_URL, + EMBEDDING_STORE_API_KEY, EMBEDDING_STORE_DATABASE_NAME, + EMBEDDING_STORE_PERSIST_PATH, EMBEDDING_STORE_TIMEOUT, EMBEDDING_STORE_DIMENSION); } public EmbeddingStoreConfig convert() { @@ -105,58 +60,44 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig { if (StringUtils.isNumeric(getParameterValue(EMBEDDING_STORE_DIMENSION))) { dimension = Integer.valueOf(getParameterValue(EMBEDDING_STORE_DIMENSION)); } - return EmbeddingStoreConfig.builder() - .provider(provider) - .baseUrl(baseUrl) - .apiKey(apiKey) - .persistPath(persistPath) - .databaseName(databaseName) - .timeOut(Long.valueOf(timeOut)) - .dimension(dimension) - .build(); + return EmbeddingStoreConfig.builder().provider(provider).baseUrl(baseUrl).apiKey(apiKey) + .persistPath(persistPath).databaseName(databaseName).timeOut(Long.valueOf(timeOut)) + .dimension(dimension).build(); } private static ArrayList getCandidateValues() { - return Lists.newArrayList( - EmbeddingStoreType.IN_MEMORY.name(), - EmbeddingStoreType.MILVUS.name(), - EmbeddingStoreType.CHROMA.name()); + return Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name(), + EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.CHROMA.name()); } private static List getBaseUrlDependency() { - return getDependency( - EMBEDDING_STORE_PROVIDER.getName(), - Lists.newArrayList( - EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.CHROMA.name()), - ImmutableMap.of( - EmbeddingStoreType.MILVUS.name(), "http://localhost:19530", + return getDependency(EMBEDDING_STORE_PROVIDER.getName(), + Lists.newArrayList(EmbeddingStoreType.MILVUS.name(), + EmbeddingStoreType.CHROMA.name()), + ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "http://localhost:19530", EmbeddingStoreType.CHROMA.name(), "http://localhost:8000")); } private static List getApiKeyDependency() { - return getDependency( - EMBEDDING_STORE_PROVIDER.getName(), + return getDependency(EMBEDDING_STORE_PROVIDER.getName(), Lists.newArrayList(EmbeddingStoreType.MILVUS.name()), ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), DEMO)); } private static List getPathDependency() { - return getDependency( - EMBEDDING_STORE_PROVIDER.getName(), + return getDependency(EMBEDDING_STORE_PROVIDER.getName(), Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name()), ImmutableMap.of(EmbeddingStoreType.IN_MEMORY.name(), "")); } private static List getDimensionDependency() { - return getDependency( - EMBEDDING_STORE_PROVIDER.getName(), + return getDependency(EMBEDDING_STORE_PROVIDER.getName(), Lists.newArrayList(EmbeddingStoreType.MILVUS.name()), ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "384")); } private static List getDatabaseNameDependency() { - return getDependency( - EMBEDDING_STORE_PROVIDER.getName(), + return getDependency(EMBEDDING_STORE_PROVIDER.getName(), Lists.newArrayList(EmbeddingStoreType.MILVUS.name()), ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "")); } diff --git a/common/src/main/java/com/tencent/supersonic/common/config/ParameterConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/ParameterConfig.java index f74b319a1..e11d035b8 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/ParameterConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/ParameterConfig.java @@ -15,9 +15,11 @@ import java.util.Map; @Service public abstract class ParameterConfig { public static final String DEMO = "demo"; - @Autowired private SystemConfigService sysConfigService; + @Autowired + private SystemConfigService sysConfigService; - @Autowired private Environment environment; + @Autowired + private Environment environment; /** @return system parameters to be set with user interface */ protected List getSysParameters() { @@ -46,10 +48,8 @@ public abstract class ParameterConfig { return value; } - protected static List getDependency( - String dependencyParameterName, - List includesValue, - Map setDefaultValue) { + protected static List getDependency(String dependencyParameterName, + List includesValue, Map setDefaultValue) { Parameter.Dependency.Show show = new Parameter.Dependency.Show(); show.setIncludesValue(includesValue); diff --git a/common/src/main/java/com/tencent/supersonic/common/config/SystemConfig.java b/common/src/main/java/com/tencent/supersonic/common/config/SystemConfig.java index 1599346a3..a37185721 100644 --- a/common/src/main/java/com/tencent/supersonic/common/config/SystemConfig.java +++ b/common/src/main/java/com/tencent/supersonic/common/config/SystemConfig.java @@ -38,11 +38,8 @@ public class SystemConfig { if (StringUtils.isBlank(name)) { return ""; } - Map nameToValue = - getParameters().stream() - .collect( - Collectors.toMap( - Parameter::getName, Parameter::getValue, (k1, k2) -> k1)); + Map nameToValue = getParameters().stream() + .collect(Collectors.toMap(Parameter::getName, Parameter::getValue, (k1, k2) -> k1)); return nameToValue.get(name); } @@ -69,15 +66,11 @@ public class SystemConfig { if (CollectionUtils.isEmpty(parameters)) { return defaultParameters; } - Map parameterNameValueMap = - parameters.stream() - .collect( - Collectors.toMap( - Parameter::getName, Parameter::getValue, (v1, v2) -> v2)); + Map parameterNameValueMap = parameters.stream() + .collect(Collectors.toMap(Parameter::getName, Parameter::getValue, (v1, v2) -> v2)); for (Parameter parameter : defaultParameters) { - parameter.setValue( - parameterNameValueMap.getOrDefault( - parameter.getName(), parameter.getDefaultValue())); + parameter.setValue(parameterNameValueMap.getOrDefault(parameter.getName(), + parameter.getDefaultValue())); } return defaultParameters; } diff --git a/common/src/main/java/com/tencent/supersonic/common/interceptor/LogInterceptor.java b/common/src/main/java/com/tencent/supersonic/common/interceptor/LogInterceptor.java index b18c92b93..2c5dfb848 100644 --- a/common/src/main/java/com/tencent/supersonic/common/interceptor/LogInterceptor.java +++ b/common/src/main/java/com/tencent/supersonic/common/interceptor/LogInterceptor.java @@ -14,8 +14,8 @@ import org.springframework.web.servlet.ModelAndView; @Slf4j public class LogInterceptor implements HandlerInterceptor { @Override - public boolean preHandle( - HttpServletRequest request, HttpServletResponse response, Object handler) { + public boolean preHandle(HttpServletRequest request, HttpServletResponse response, + Object handler) { // use previous traceId String traceId = request.getHeader(TraceIdUtil.TRACE_ID); if (StringUtils.isBlank(traceId)) { @@ -27,17 +27,12 @@ public class LogInterceptor implements HandlerInterceptor { } @Override - public void postHandle( - HttpServletRequest request, - HttpServletResponse response, - Object handler, - ModelAndView modelAndView) - throws Exception {} + public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, + ModelAndView modelAndView) throws Exception {} @Override - public void afterCompletion( - HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) - throws Exception { + public void afterCompletion(HttpServletRequest request, HttpServletResponse response, + Object handler, Exception ex) throws Exception { // remove after Completing TraceIdUtil.remove(); } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/AggregateEnum.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/AggregateEnum.java index 8f0cdb6e5..a0a5baee5 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/AggregateEnum.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/AggregateEnum.java @@ -5,13 +5,9 @@ import java.util.Map; import java.util.stream.Collectors; public enum AggregateEnum { - MOST("最多", "max"), - HIGHEST("最高", "max"), - MAXIMUN("最大", "max"), - LEAST("最少", "min"), - SMALLEST("最小", "min"), - LOWEST("最低", "min"), - AVERAGE("平均", "avg"); + MOST("最多", "max"), HIGHEST("最高", "max"), MAXIMUN("最大", "max"), LEAST("最少", + "min"), SMALLEST("最小", "min"), LOWEST("最低", "min"), AVERAGE("平均", "avg"); + private String aggregateCh; private String aggregateEN; @@ -29,9 +25,7 @@ public enum AggregateEnum { } public static Map getAggregateEnum() { - return Arrays.stream(AggregateEnum.values()) - .collect( - Collectors.toMap( - AggregateEnum::getAggregateCh, AggregateEnum::getAggregateEN)); + return Arrays.stream(AggregateEnum.values()).collect( + Collectors.toMap(AggregateEnum::getAggregateCh, AggregateEnum::getAggregateEN)); } } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/CustomExpressionDeParser.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/CustomExpressionDeParser.java index d7b61118e..31e7bb8a8 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/CustomExpressionDeParser.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/CustomExpressionDeParser.java @@ -15,8 +15,8 @@ public class CustomExpressionDeParser extends ExpressionDeParser { private boolean dealNull; private boolean dealNotNull; - public CustomExpressionDeParser( - Set removeFieldNames, boolean dealNull, boolean dealNotNull) { + public CustomExpressionDeParser(Set removeFieldNames, boolean dealNull, + boolean dealNotNull) { this.removeFieldNames = removeFieldNames; this.dealNull = dealNull; this.dealNotNull = dealNotNull; @@ -45,12 +45,10 @@ public class CustomExpressionDeParser extends ExpressionDeParser { Expression leftExpression = ((AndExpression) binaryExpression).getLeftExpression(); Expression rightExpression = ((AndExpression) binaryExpression).getRightExpression(); - boolean leftIsNull = - leftExpression instanceof IsNullExpression - && shouldSkip((IsNullExpression) leftExpression); - boolean rightIsNull = - rightExpression instanceof IsNullExpression - && shouldSkip((IsNullExpression) rightExpression); + boolean leftIsNull = leftExpression instanceof IsNullExpression + && shouldSkip((IsNullExpression) leftExpression); + boolean rightIsNull = rightExpression instanceof IsNullExpression + && shouldSkip((IsNullExpression) rightExpression); if (leftIsNull && rightIsNull) { // Skip both expressions diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/DateFunctionHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/DateFunctionHelper.java index aad1e73a9..850ac16e8 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/DateFunctionHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/DateFunctionHelper.java @@ -13,8 +13,8 @@ import net.sf.jsqlparser.expression.operators.relational.ExpressionList; @Slf4j public class DateFunctionHelper { - public static String getStartDateStr( - ComparisonOperator minorThanEquals, ExpressionList expressions) { + public static String getStartDateStr(ComparisonOperator minorThanEquals, + ExpressionList expressions) { String unitValue = getUnit(expressions); String dateValue = getEndDateValue(expressions); String dateStr = ""; diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/ExpressionReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/ExpressionReplaceVisitor.java index 7670e926e..639ad0e6a 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/ExpressionReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/ExpressionReplaceVisitor.java @@ -23,9 +23,8 @@ public class ExpressionReplaceVisitor extends ExpressionVisitorAdapter { expr.getWhenExpression().accept(this); if (expr.getThenExpression() instanceof Column) { Column column = (Column) expr.getThenExpression(); - Expression expression = - QueryExpressionReplaceVisitor.getExpression( - QueryExpressionReplaceVisitor.getReplaceExpr(column, fieldExprMap)); + Expression expression = QueryExpressionReplaceVisitor.getExpression( + QueryExpressionReplaceVisitor.getReplaceExpr(column, fieldExprMap)); if (Objects.nonNull(expression)) { expr.setThenExpression(expression); } @@ -52,20 +51,16 @@ public class ExpressionReplaceVisitor extends ExpressionVisitorAdapter { } } if (left instanceof Column) { - Expression expression = - QueryExpressionReplaceVisitor.getExpression( - QueryExpressionReplaceVisitor.getReplaceExpr( - (Column) left, fieldExprMap)); + Expression expression = QueryExpressionReplaceVisitor.getExpression( + QueryExpressionReplaceVisitor.getReplaceExpr((Column) left, fieldExprMap)); if (Objects.nonNull(expression)) { expr.setLeftExpression(expression); leftVisited = true; } } if (right instanceof Column) { - Expression expression = - QueryExpressionReplaceVisitor.getExpression( - QueryExpressionReplaceVisitor.getReplaceExpr( - (Column) right, fieldExprMap)); + Expression expression = QueryExpressionReplaceVisitor.getExpression( + QueryExpressionReplaceVisitor.getReplaceExpr((Column) right, fieldExprMap)); if (Objects.nonNull(expression)) { expr.setRightExpression(expression); rightVisited = true; @@ -81,9 +76,8 @@ public class ExpressionReplaceVisitor extends ExpressionVisitorAdapter { private boolean visitFunction(Function function) { if (function.getParameters().getExpressions().get(0) instanceof Column) { - Expression expression = - QueryExpressionReplaceVisitor.getExpression( - QueryExpressionReplaceVisitor.getReplaceExpr(function, fieldExprMap)); + Expression expression = QueryExpressionReplaceVisitor.getExpression( + QueryExpressionReplaceVisitor.getReplaceExpr(function, fieldExprMap)); if (Objects.nonNull(expression)) { ExpressionList expressions = new ExpressionList<>(); expressions.add(expression); diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldAndValueAcquireVisitor.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldAndValueAcquireVisitor.java index 0f12865d9..008fc38a0 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldAndValueAcquireVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldAndValueAcquireVisitor.java @@ -130,8 +130,8 @@ public class FieldAndValueAcquireVisitor extends ExpressionVisitorAdapter { Arrays.stream(DatePeriodEnum.values()).collect(Collectors.toList()); DatePeriodEnum periodEnum = DatePeriodEnum.get(functionName); if (Objects.nonNull(periodEnum) && collect.contains(periodEnum)) { - fieldExpression.setFieldValue( - getFieldValue(rightExpression) + periodEnum.getChName()); + fieldExpression + .setFieldValue(getFieldValue(rightExpression) + periodEnum.getChName()); return fieldExpression; } else { // deal with aggregate function diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldValueReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldValueReplaceVisitor.java index c19db1633..182da9d00 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldValueReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FieldValueReplaceVisitor.java @@ -31,8 +31,8 @@ public class FieldValueReplaceVisitor extends ExpressionVisitorAdapter { private boolean exactReplace; private Map> filedNameToValueMap; - public FieldValueReplaceVisitor( - boolean exactReplace, Map> filedNameToValueMap) { + public FieldValueReplaceVisitor(boolean exactReplace, + Map> filedNameToValueMap) { this.exactReplace = exactReplace; this.filedNameToValueMap = filedNameToValueMap; } @@ -67,24 +67,20 @@ public class FieldValueReplaceVisitor extends ExpressionVisitorAdapter { ExpressionList rightItemsList = (ExpressionList) inExpression.getRightExpression(); List expressions = rightItemsList.getExpressions(); List values = new ArrayList<>(); - expressions.stream() - .forEach( - o -> { - if (o instanceof StringValue) { - values.add(((StringValue) o).getValue()); - } - }); + expressions.stream().forEach(o -> { + if (o instanceof StringValue) { + values.add(((StringValue) o).getValue()); + } + }); if (valueMap == null || CollectionUtils.isEmpty(values)) { return; } List newExpressions = new ArrayList<>(); - values.stream() - .forEach( - o -> { - String replaceValue = valueMap.getOrDefault(o, o); - StringValue stringValue = new StringValue(replaceValue); - newExpressions.add(stringValue); - }); + values.stream().forEach(o -> { + String replaceValue = valueMap.getOrDefault(o, o); + StringValue stringValue = new StringValue(replaceValue); + newExpressions.add(stringValue); + }); rightItemsList.setExpressions(newExpressions); inExpression.setRightExpression(rightItemsList); } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FiledNameReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FiledNameReplaceVisitor.java index 3d4fe836f..d66f2fb92 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FiledNameReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FiledNameReplaceVisitor.java @@ -34,11 +34,9 @@ public class FiledNameReplaceVisitor extends ExpressionVisitorAdapter { Expression leftExpression = expr.getLeftExpression(); Expression rightExpression = expr.getRightExpression(); - if (!(rightExpression instanceof StringValue) - || !(leftExpression instanceof Column) + if (!(rightExpression instanceof StringValue) || !(leftExpression instanceof Column) || CollectionUtils.isEmpty(fieldValueToFieldNames) - || Objects.isNull(rightExpression) - || Objects.isNull(leftExpression)) { + || Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) { return; } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FunctionAliasReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FunctionAliasReplaceVisitor.java index 00e8aa9ce..303177d70 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FunctionAliasReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FunctionAliasReplaceVisitor.java @@ -21,8 +21,8 @@ public class FunctionAliasReplaceVisitor extends SelectItemVisitorAdapter { // 2.alias's fieldName not equal. "sum(pv) as pv" cannot be replaced. if (Objects.nonNull(selectExpressionItem.getAlias()) && !selectExpressionItem.getAlias().getName().equalsIgnoreCase(columnName)) { - aliasToActualExpression.put( - selectExpressionItem.getAlias().getName(), function.toString()); + aliasToActualExpression.put(selectExpressionItem.getAlias().getName(), + function.toString()); selectExpressionItem.setAlias(null); } } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FunctionNameReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FunctionNameReplaceVisitor.java index fca4d7965..15a9db57b 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FunctionNameReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/FunctionNameReplaceVisitor.java @@ -16,8 +16,8 @@ public class FunctionNameReplaceVisitor extends ExpressionVisitorAdapter { private Map functionMap; private Map functionCallMap; - public FunctionNameReplaceVisitor( - Map functionMap, Map functionCallMap) { + public FunctionNameReplaceVisitor(Map functionMap, + Map functionCallMap) { this.functionMap = functionMap; this.functionCallMap = functionCallMap; } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/GroupByFunctionReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/GroupByFunctionReplaceVisitor.java index ffad895d0..5e2c2b0b5 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/GroupByFunctionReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/GroupByFunctionReplaceVisitor.java @@ -19,8 +19,8 @@ public class GroupByFunctionReplaceVisitor implements GroupByVisitor { private Map functionMap; private Map functionCallMap; - public GroupByFunctionReplaceVisitor( - Map functionMap, Map functionCallMap) { + public GroupByFunctionReplaceVisitor(Map functionMap, + Map functionCallMap) { this.functionMap = functionMap; this.functionCallMap = functionCallMap; } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/GroupByReplaceVisitor.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/GroupByReplaceVisitor.java index f63d88af2..dd442722c 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/GroupByReplaceVisitor.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/GroupByReplaceVisitor.java @@ -53,11 +53,8 @@ public class GroupByReplaceVisitor implements GroupByVisitor { return expression.toString(); } - private void replaceExpression( - List groupByExpressions, - int index, - Expression expression, - String replaceColumn) { + private void replaceExpression(List groupByExpressions, int index, + Expression expression, String replaceColumn) { if (expression instanceof Column) { groupByExpressions.set(index, new Column(replaceColumn)); } else if (expression instanceof Function) { @@ -68,8 +65,7 @@ public class GroupByReplaceVisitor implements GroupByVisitor { Function function = (Function) expression; if (function.getParameters().size() > 1) { - function.getParameters().stream() - .skip(1) + function.getParameters().stream().skip(1) .forEach(e -> newExpressionList.add((Function) e)); } function.setParameters(newExpressionList); diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/JsqlConstants.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/JsqlConstants.java index 3d2cff49a..c368483a8 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/JsqlConstants.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/JsqlConstants.java @@ -27,26 +27,14 @@ public class JsqlConstants { public static final String IN_CONSTANT = " 1 in (1) "; public static final String LIKE_CONSTANT = "1 like 1"; public static final String IN = "IN"; - public static final Map rightMap = - Stream.of( - new AbstractMap.SimpleEntry<>("<=", "<="), - new AbstractMap.SimpleEntry<>("<", "<"), - new AbstractMap.SimpleEntry<>(">=", "<="), - new AbstractMap.SimpleEntry<>(">", "<"), - new AbstractMap.SimpleEntry<>("=", "<=")) - .collect( - toMap( - AbstractMap.SimpleEntry::getKey, - AbstractMap.SimpleEntry::getValue)); - public static final Map leftMap = - Stream.of( - new AbstractMap.SimpleEntry<>("<=", ">="), - new AbstractMap.SimpleEntry<>("<", ">"), - new AbstractMap.SimpleEntry<>(">=", "<="), - new AbstractMap.SimpleEntry<>(">", "<"), - new AbstractMap.SimpleEntry<>("=", ">=")) - .collect( - toMap( - AbstractMap.SimpleEntry::getKey, - AbstractMap.SimpleEntry::getValue)); + public static final Map rightMap = Stream.of( + new AbstractMap.SimpleEntry<>("<=", "<="), new AbstractMap.SimpleEntry<>("<", "<"), + new AbstractMap.SimpleEntry<>(">=", "<="), new AbstractMap.SimpleEntry<>(">", "<"), + new AbstractMap.SimpleEntry<>("=", "<=")) + .collect(toMap(AbstractMap.SimpleEntry::getKey, AbstractMap.SimpleEntry::getValue)); + public static final Map leftMap = Stream.of( + new AbstractMap.SimpleEntry<>("<=", ">="), new AbstractMap.SimpleEntry<>("<", ">"), + new AbstractMap.SimpleEntry<>(">=", "<="), new AbstractMap.SimpleEntry<>(">", "<"), + new AbstractMap.SimpleEntry<>("=", ">=")) + .collect(toMap(AbstractMap.SimpleEntry::getKey, AbstractMap.SimpleEntry::getValue)); } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/ParseVisitorHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/ParseVisitorHelper.java index 3a6d19c50..37297a00d 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/ParseVisitorHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/ParseVisitorHelper.java @@ -13,8 +13,8 @@ import java.util.stream.Collectors; @Slf4j public class ParseVisitorHelper { - public void replaceColumn( - Column column, Map fieldNameMap, boolean exactReplace) { + public void replaceColumn(Column column, Map fieldNameMap, + boolean exactReplace) { String columnName = StringUtil.replaceBackticks(column.getColumnName()); String replaceColumn = getReplaceValue(columnName, fieldNameMap, exactReplace); if (StringUtils.isNotBlank(replaceColumn)) { @@ -22,8 +22,8 @@ public class ParseVisitorHelper { } } - public String getReplaceValue( - String beforeValue, Map valueMap, boolean exactReplace) { + public String getReplaceValue(String beforeValue, Map valueMap, + boolean exactReplace) { String value = valueMap.get(beforeValue); if (StringUtils.isNotBlank(value)) { return value; @@ -31,19 +31,13 @@ public class ParseVisitorHelper { if (exactReplace) { return null; } - Optional> first = - valueMap.entrySet().stream() - .sorted( - (k1, k2) -> { - String k1Value = k1.getKey(); - String k2Value = k2.getKey(); - Double k1Similarity = getSimilarity(beforeValue, k1Value); - Double k2Similarity = getSimilarity(beforeValue, k2Value); - return k2Similarity.compareTo(k1Similarity); - }) - .collect(Collectors.toList()) - .stream() - .findFirst(); + Optional> first = valueMap.entrySet().stream().sorted((k1, k2) -> { + String k1Value = k1.getKey(); + String k2Value = k2.getKey(); + Double k1Similarity = getSimilarity(beforeValue, k1Value); + Double k2Similarity = getSimilarity(beforeValue, k2Value); + return k2Similarity.compareTo(k1Similarity); + }).collect(Collectors.toList()).stream().findFirst(); if (first.isPresent()) { return first.get().getValue(); @@ -68,16 +62,12 @@ public class ParseVisitorHelper { char cj = word2.charAt(j - 1); if (ci == cj) { dp[i][j] = dp[i - 1][j - 1]; - } else if (i > 1 - && j > 1 - && ci == word2.charAt(j - 2) + } else if (i > 1 && j > 1 && ci == word2.charAt(j - 2) && cj == word1.charAt(i - 2)) { dp[i][j] = 1 + Math.min(dp[i - 2][j - 2], Math.min(dp[i][j - 1], dp[i - 1][j])); } else { - dp[i][j] = - Math.min( - dp[i - 1][j - 1] + 1, - Math.min(dp[i][j - 1] + 1, dp[i - 1][j] + 1)); + dp[i][j] = Math.min(dp[i - 1][j - 1] + 1, + Math.min(dp[i][j - 1] + 1, dp[i - 1][j] + 1)); } } } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlAddHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlAddHelper.java index fe18d6cb8..b7918c5e7 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlAddHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlAddHelper.java @@ -43,32 +43,21 @@ public class SqlAddHelper { } if (selectStatement instanceof PlainSelect) { PlainSelect plainSelect = (PlainSelect) selectStatement; - fields.stream() - .filter(Objects::nonNull) - .forEach( - field -> { - SelectItem selectExpressionItem = - new SelectItem(new Column(field)); - plainSelect.addSelectItems(selectExpressionItem); - }); + fields.stream().filter(Objects::nonNull).forEach(field -> { + SelectItem selectExpressionItem = new SelectItem(new Column(field)); + plainSelect.addSelectItems(selectExpressionItem); + }); } else if (selectStatement instanceof SetOperationList) { SetOperationList setOperationList = (SetOperationList) selectStatement; if (!CollectionUtils.isEmpty(setOperationList.getSelects())) { - setOperationList - .getSelects() - .forEach( - subSelectBody -> { - PlainSelect subPlainSelect = (PlainSelect) subSelectBody; - fields.stream() - .forEach( - field -> { - SelectItem selectExpressionItem = - new SelectItem(new Column(field)); - subPlainSelect.addSelectItems( - selectExpressionItem); - }); - }); + setOperationList.getSelects().forEach(subSelectBody -> { + PlainSelect subPlainSelect = (PlainSelect) subSelectBody; + fields.stream().forEach(field -> { + SelectItem selectExpressionItem = new SelectItem(new Column(field)); + subPlainSelect.addSelectItems(selectExpressionItem); + }); + }); } } return selectStatement.toString(); @@ -88,13 +77,10 @@ public class SqlAddHelper { SetOperationList setOperationList = (SetOperationList) selectStatement.getSetOperationList(); if (!CollectionUtils.isEmpty(setOperationList.getSelects())) { - setOperationList - .getSelects() - .forEach( - subSelectBody -> { - PlainSelect subPlainSelect = (PlainSelect) subSelectBody; - plainSelectList.add(subPlainSelect); - }); + setOperationList.getSelects().forEach(subSelectBody -> { + PlainSelect subPlainSelect = (PlainSelect) subSelectBody; + plainSelectList.add(subPlainSelect); + }); } } @@ -238,18 +224,15 @@ public class SqlAddHelper { if (!(selectStatement instanceof PlainSelect)) { return sql; } - selectStatement.accept( - new SelectVisitorAdapter() { - @Override - public void visit(PlainSelect plainSelect) { - addAggregateToSelectItems( - plainSelect.getSelectItems(), fieldNameToAggregate); - addAggregateToOrderByItems( - plainSelect.getOrderByElements(), fieldNameToAggregate); - addAggregateToGroupByItems(plainSelect.getGroupBy(), fieldNameToAggregate); - addAggregateToWhereItems(plainSelect.getWhere(), fieldNameToAggregate); - } - }); + selectStatement.accept(new SelectVisitorAdapter() { + @Override + public void visit(PlainSelect plainSelect) { + addAggregateToSelectItems(plainSelect.getSelectItems(), fieldNameToAggregate); + addAggregateToOrderByItems(plainSelect.getOrderByElements(), fieldNameToAggregate); + addAggregateToGroupByItems(plainSelect.getGroupBy(), fieldNameToAggregate); + addAggregateToWhereItems(plainSelect.getWhere(), fieldNameToAggregate); + } + }); return selectStatement.toString(); } @@ -276,8 +259,8 @@ public class SqlAddHelper { return selectStatement.toString(); } - private static void addAggregateToSelectItems( - List> selectItems, Map fieldNameToAggregate) { + private static void addAggregateToSelectItems(List> selectItems, + Map fieldNameToAggregate) { for (SelectItem selectItem : selectItems) { Expression expression = selectItem.getExpression(); Function function = @@ -289,8 +272,8 @@ public class SqlAddHelper { } } - private static void addAggregateToOrderByItems( - List orderByElements, Map fieldNameToAggregate) { + private static void addAggregateToOrderByItems(List orderByElements, + Map fieldNameToAggregate) { if (orderByElements == null) { return; } @@ -305,8 +288,8 @@ public class SqlAddHelper { } } - private static void addAggregateToGroupByItems( - GroupByElement groupByElement, Map fieldNameToAggregate) { + private static void addAggregateToGroupByItems(GroupByElement groupByElement, + Map fieldNameToAggregate) { if (groupByElement == null) { return; } @@ -321,16 +304,16 @@ public class SqlAddHelper { } } - private static void addAggregateToWhereItems( - Expression whereExpression, Map fieldNameToAggregate) { + private static void addAggregateToWhereItems(Expression whereExpression, + Map fieldNameToAggregate) { if (whereExpression == null) { return; } modifyWhereExpression(whereExpression, fieldNameToAggregate); } - private static void modifyWhereExpression( - Expression whereExpression, Map fieldNameToAggregate) { + private static void modifyWhereExpression(Expression whereExpression, + Map fieldNameToAggregate) { if (SqlSelectHelper.isLogicExpression(whereExpression)) { if (whereExpression instanceof AndExpression) { AndExpression andExpression = (AndExpression) whereExpression; @@ -347,15 +330,15 @@ public class SqlAddHelper { modifyWhereExpression(rightExpression, fieldNameToAggregate); } } else if (whereExpression instanceof Parenthesis) { - modifyWhereExpression( - ((Parenthesis) whereExpression).getExpression(), fieldNameToAggregate); + modifyWhereExpression(((Parenthesis) whereExpression).getExpression(), + fieldNameToAggregate); } else { setAggToFunction(whereExpression, fieldNameToAggregate); } } - private static void setAggToFunction( - Expression expression, Map fieldNameToAggregate) { + private static void setAggToFunction(Expression expression, + Map fieldNameToAggregate) { if (!(expression instanceof ComparisonOperator)) { return; } @@ -363,20 +346,16 @@ public class SqlAddHelper { if (comparisonOperator.getRightExpression() instanceof Column) { String columnName = ((Column) (comparisonOperator).getRightExpression()).getColumnName(); - Function function = - SqlSelectFunctionHelper.getFunction( - comparisonOperator.getRightExpression(), - fieldNameToAggregate.get(columnName)); + Function function = SqlSelectFunctionHelper.getFunction( + comparisonOperator.getRightExpression(), fieldNameToAggregate.get(columnName)); if (Objects.nonNull(function)) { comparisonOperator.setRightExpression(function); } } if (comparisonOperator.getLeftExpression() instanceof Column) { String columnName = ((Column) (comparisonOperator).getLeftExpression()).getColumnName(); - Function function = - SqlSelectFunctionHelper.getFunction( - comparisonOperator.getLeftExpression(), - fieldNameToAggregate.get(columnName)); + Function function = SqlSelectFunctionHelper.getFunction( + comparisonOperator.getLeftExpression(), fieldNameToAggregate.get(columnName)); if (Objects.nonNull(function)) { comparisonOperator.setLeftExpression(function); } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlAsHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlAsHelper.java index a3f7b7d99..a93639120 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlAsHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlAsHelper.java @@ -27,18 +27,17 @@ public class SqlAsHelper { if (plainSelect instanceof Select) { Select select = plainSelect; Select selectBody = select.getSelectBody(); - selectBody.accept( - new SelectVisitorAdapter() { - @Override - public void visit(PlainSelect plainSelect) { - extractAliasesFromSelect(plainSelect, aliases); - } + selectBody.accept(new SelectVisitorAdapter() { + @Override + public void visit(PlainSelect plainSelect) { + extractAliasesFromSelect(plainSelect, aliases); + } - @Override - public void visit(WithItem withItem) { - withItem.getSelectBody().accept(this); - } - }); + @Override + public void visit(WithItem withItem) { + withItem.getSelectBody().accept(this); + } + }); } } return new ArrayList<>(aliases); diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlEditEnum.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlEditEnum.java index cb75229e7..594908c2c 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlEditEnum.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlEditEnum.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.common.jsqlparser; public enum SqlEditEnum { - NUMBER_FILTER, - DATEDIFF + NUMBER_FILTER, DATEDIFF } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlRemoveHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlRemoveHelper.java index 0b3bbe937..ba1b5b551 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlRemoveHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlRemoveHelper.java @@ -67,15 +67,14 @@ public class SqlRemoveHelper { } List> selectItems = ((PlainSelect) selectStatement).getSelectItems(); Set fields = new HashSet<>(); - selectItems.removeIf( - selectItem -> { - String field = selectItem.getExpression().toString(); - if (fields.contains(field)) { - return true; - } - fields.add(field); - return false; - }); + selectItems.removeIf(selectItem -> { + String field = selectItem.getExpression().toString(); + if (fields.contains(field)) { + return true; + } + fields.add(field); + return false; + }); ((PlainSelect) selectStatement).setSelectItems(selectItems); return selectStatement.toString(); } @@ -85,18 +84,17 @@ public class SqlRemoveHelper { if (!(selectStatement instanceof PlainSelect)) { return sql; } - selectStatement.accept( - new SelectVisitorAdapter() { - @Override - public void visit(PlainSelect plainSelect) { - removeWhereCondition(plainSelect.getWhere(), removeFieldNames); - } - }); + selectStatement.accept(new SelectVisitorAdapter() { + @Override + public void visit(PlainSelect plainSelect) { + removeWhereCondition(plainSelect.getWhere(), removeFieldNames); + } + }); return removeNumberFilter(selectStatement.toString()); } - private static void removeWhereCondition( - Expression whereExpression, Set removeFieldNames) { + private static void removeWhereCondition(Expression whereExpression, + Set removeFieldNames) { if (whereExpression == null) { return; } @@ -121,8 +119,8 @@ public class SqlRemoveHelper { return selectStatement.toString(); } - private static void removeWhereExpression( - Expression whereExpression, Set removeFieldNames) { + private static void removeWhereExpression(Expression whereExpression, + Set removeFieldNames) { if (SqlSelectHelper.isLogicExpression(whereExpression)) { BinaryExpression binaryExpression = (BinaryExpression) whereExpression; Expression leftExpression = binaryExpression.getLeftExpression(); @@ -131,8 +129,8 @@ public class SqlRemoveHelper { removeWhereExpression(leftExpression, removeFieldNames); removeWhereExpression(rightExpression, removeFieldNames); } else if (whereExpression instanceof Parenthesis) { - removeWhereExpression( - ((Parenthesis) whereExpression).getExpression(), removeFieldNames); + removeWhereExpression(((Parenthesis) whereExpression).getExpression(), + removeFieldNames); } else { removeExpressionWithConstant(whereExpression, removeFieldNames); } @@ -152,8 +150,8 @@ public class SqlRemoveHelper { return constant; } - private static void removeExpressionWithConstant( - Expression expression, Set removeFieldNames) { + private static void removeExpressionWithConstant(Expression expression, + Set removeFieldNames) { try { if (expression instanceof ComparisonOperator) { handleComparisonOperator((ComparisonOperator) expression, removeFieldNames); @@ -167,13 +165,10 @@ public class SqlRemoveHelper { } } - private static void handleComparisonOperator( - ComparisonOperator comparisonOperator, Set removeFieldNames) - throws JSQLParserException { - String columnName = - SqlSelectHelper.getColumnName( - comparisonOperator.getLeftExpression(), - comparisonOperator.getRightExpression()); + private static void handleComparisonOperator(ComparisonOperator comparisonOperator, + Set removeFieldNames) throws JSQLParserException { + String columnName = SqlSelectHelper.getColumnName(comparisonOperator.getLeftExpression(), + comparisonOperator.getRightExpression()); if (!removeFieldNames.contains(columnName)) { return; } @@ -185,9 +180,8 @@ public class SqlRemoveHelper { private static void handleInExpression(InExpression inExpression, Set removeFieldNames) throws JSQLParserException { - String columnName = - SqlSelectHelper.getColumnName( - inExpression.getLeftExpression(), inExpression.getRightExpression()); + String columnName = SqlSelectHelper.getColumnName(inExpression.getLeftExpression(), + inExpression.getRightExpression()); if (!removeFieldNames.contains(columnName)) { return; } @@ -196,12 +190,10 @@ public class SqlRemoveHelper { updateInExpression(inExpression, constantExpression); } - private static void handleLikeExpression( - LikeExpression likeExpression, Set removeFieldNames) - throws JSQLParserException { - String columnName = - SqlSelectHelper.getColumnName( - likeExpression.getLeftExpression(), likeExpression.getRightExpression()); + private static void handleLikeExpression(LikeExpression likeExpression, + Set removeFieldNames) throws JSQLParserException { + String columnName = SqlSelectHelper.getColumnName(likeExpression.getLeftExpression(), + likeExpression.getRightExpression()); if (!removeFieldNames.contains(columnName)) { return; } @@ -210,8 +202,8 @@ public class SqlRemoveHelper { updateLikeExpression(likeExpression, constantExpression); } - private static void updateComparisonOperator( - ComparisonOperator original, ComparisonOperator constantExpression) { + private static void updateComparisonOperator(ComparisonOperator original, + ComparisonOperator constantExpression) { original.setLeftExpression(constantExpression.getLeftExpression()); original.setRightExpression(constantExpression.getRightExpression()); original.setASTNode(constantExpression.getASTNode()); @@ -223,8 +215,8 @@ public class SqlRemoveHelper { original.setASTNode(constantExpression.getASTNode()); } - private static void updateLikeExpression( - LikeExpression original, LikeExpression constantExpression) { + private static void updateLikeExpression(LikeExpression original, + LikeExpression constantExpression) { original.setLeftExpression(constantExpression.getLeftExpression()); original.setRightExpression(constantExpression.getRightExpression()); } @@ -234,13 +226,12 @@ public class SqlRemoveHelper { if (!(selectStatement instanceof PlainSelect)) { return sql; } - selectStatement.accept( - new SelectVisitorAdapter() { - @Override - public void visit(PlainSelect plainSelect) { - removeWhereCondition(plainSelect.getHaving(), removeFieldNames); - } - }); + selectStatement.accept(new SelectVisitorAdapter() { + @Override + public void visit(PlainSelect plainSelect) { + removeWhereCondition(plainSelect.getHaving(), removeFieldNames); + } + }); return removeNumberFilter(selectStatement.toString()); } @@ -254,16 +245,13 @@ public class SqlRemoveHelper { return sql; } ExpressionList groupByExpressionList = groupByElement.getGroupByExpressionList(); - groupByExpressionList - .getExpressions() - .removeIf( - expression -> { - if (expression instanceof Column) { - Column column = (Column) expression; - return fields.contains(column.getColumnName()); - } - return false; - }); + groupByExpressionList.getExpressions().removeIf(expression -> { + if (expression instanceof Column) { + Column column = (Column) expression; + return fields.contains(column.getColumnName()); + } + return false; + }); if (CollectionUtils.isEmpty(groupByExpressionList.getExpressions())) { ((PlainSelect) selectStatement).setGroupByElement(null); } @@ -279,15 +267,14 @@ public class SqlRemoveHelper { Iterator> iterator = selectItems.iterator(); while (iterator.hasNext()) { SelectItem selectItem = iterator.next(); - selectItem.accept( - new SelectItemVisitorAdapter() { - @Override - public void visit(SelectItem item) { - if (fields.contains(item.getExpression().toString())) { - iterator.remove(); - } - } - }); + selectItem.accept(new SelectItemVisitorAdapter() { + @Override + public void visit(SelectItem item) { + if (fields.contains(item.getExpression().toString())) { + iterator.remove(); + } + } + }); } if (selectItems.isEmpty()) { selectItems.add(new SelectItem(new AllColumns())); @@ -345,17 +332,14 @@ public class SqlRemoveHelper { } } - private static Expression dealComparisonOperatorFilter( - Expression expression, SqlEditEnum sqlEditEnum) { + private static Expression dealComparisonOperatorFilter(Expression expression, + SqlEditEnum sqlEditEnum) { if (Objects.isNull(expression)) { return null; } - if (expression instanceof GreaterThanEquals - || expression instanceof GreaterThan - || expression instanceof MinorThan - || expression instanceof MinorThanEquals - || expression instanceof EqualsTo - || expression instanceof NotEqualsTo) { + if (expression instanceof GreaterThanEquals || expression instanceof GreaterThan + || expression instanceof MinorThan || expression instanceof MinorThanEquals + || expression instanceof EqualsTo || expression instanceof NotEqualsTo) { return removeSingleFilter((ComparisonOperator) expression, sqlEditEnum); } else if (expression instanceof InExpression) { InExpression inExpression = (InExpression) expression; @@ -369,14 +353,14 @@ public class SqlRemoveHelper { return expression; } - private static Expression removeSingleFilter( - ComparisonOperator comparisonExpression, SqlEditEnum sqlEditEnum) { + private static Expression removeSingleFilter(ComparisonOperator comparisonExpression, + SqlEditEnum sqlEditEnum) { Expression leftExpression = comparisonExpression.getLeftExpression(); return recursionBase(leftExpression, comparisonExpression, sqlEditEnum); } - private static Expression recursionBase( - Expression leftExpression, Expression expression, SqlEditEnum sqlEditEnum) { + private static Expression recursionBase(Expression leftExpression, Expression expression, + SqlEditEnum sqlEditEnum) { if (sqlEditEnum.equals(SqlEditEnum.NUMBER_FILTER)) { return distinguishNumberFilter(leftExpression, expression); } @@ -386,8 +370,8 @@ public class SqlRemoveHelper { return expression; } - private static Expression distinguishNumberFilter( - Expression leftExpression, Expression expression) { + private static Expression distinguishNumberFilter(Expression leftExpression, + Expression expression) { if (leftExpression instanceof LongValue) { return null; } else { @@ -403,8 +387,8 @@ public class SqlRemoveHelper { return removeIsNullOrNotNullInWhere(false, true, sql, removeFieldNames); } - public static String removeIsNullOrNotNullInWhere( - boolean dealNull, boolean dealNotNull, String sql, Set removeFieldNames) { + public static String removeIsNullOrNotNullInWhere(boolean dealNull, boolean dealNotNull, + String sql, Set removeFieldNames) { Select selectStatement = SqlSelectHelper.getSelect(sql); if (!(selectStatement instanceof PlainSelect)) { return sql; diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java index 1de90cb99..32086230f 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelper.java @@ -46,57 +46,46 @@ import java.util.function.UnaryOperator; /** Sql Parser replace Helper */ @Slf4j public class SqlReplaceHelper { - public static String replaceAggFields( - String sql, Map> fieldNameToAggMap) { + public static String replaceAggFields(String sql, + Map> fieldNameToAggMap) { Select selectStatement = SqlSelectHelper.getSelect(sql); if (!(selectStatement instanceof PlainSelect)) { return sql; } - ((PlainSelect) selectStatement) - .getSelectItems().stream() - .forEach( - o -> { - SelectItem selectExpressionItem = (SelectItem) o; - if (selectExpressionItem.getExpression() instanceof Function) { - Function function = - (Function) selectExpressionItem.getExpression(); - Column column = - (Column) - function.getParameters() - .getExpressions() - .get(0); - if (fieldNameToAggMap.containsKey(column.getColumnName())) { - Pair agg = - fieldNameToAggMap.get(column.getColumnName()); - String field = agg.getKey(); - String func = agg.getRight(); - if (AggOperatorEnum.isCountDistinct(func)) { - function.setName("count"); - function.setDistinct(true); - } else { - function.setName(func); - } - function.withParameters(new Column(field)); - if (Objects.nonNull(selectExpressionItem.getAlias()) - && StringUtils.isNotBlank(field)) { - selectExpressionItem.getAlias().setName(field); - } - } - } - }); + ((PlainSelect) selectStatement).getSelectItems().stream().forEach(o -> { + SelectItem selectExpressionItem = (SelectItem) o; + if (selectExpressionItem.getExpression() instanceof Function) { + Function function = (Function) selectExpressionItem.getExpression(); + Column column = (Column) function.getParameters().getExpressions().get(0); + if (fieldNameToAggMap.containsKey(column.getColumnName())) { + Pair agg = fieldNameToAggMap.get(column.getColumnName()); + String field = agg.getKey(); + String func = agg.getRight(); + if (AggOperatorEnum.isCountDistinct(func)) { + function.setName("count"); + function.setDistinct(true); + } else { + function.setName(func); + } + function.withParameters(new Column(field)); + if (Objects.nonNull(selectExpressionItem.getAlias()) + && StringUtils.isNotBlank(field)) { + selectExpressionItem.getAlias().setName(field); + } + } + } + }); return selectStatement.toString(); } - public static String replaceValue( - String sql, Map> filedNameToValueMap) { + public static String replaceValue(String sql, + Map> filedNameToValueMap) { return replaceValue(sql, filedNameToValueMap, true); } - public static String replaceValue( - String sql, - Map> filedNameToValueMap, - boolean exactReplace) { + public static String replaceValue(String sql, + Map> filedNameToValueMap, boolean exactReplace) { Select selectStatement = SqlSelectHelper.getSelect(sql); if (!(selectStatement instanceof PlainSelect)) { return sql; @@ -113,8 +102,8 @@ public class SqlReplaceHelper { return selectStatement.toString(); } - public static String replaceFieldNameByValue( - String sql, Map> fieldValueToFieldNames) { + public static String replaceFieldNameByValue(String sql, + Map> fieldValueToFieldNames) { Select selectStatement = SqlSelectHelper.getSelect(sql); if (!(selectStatement instanceof PlainSelect)) { return sql; @@ -145,14 +134,11 @@ public class SqlReplaceHelper { } else if (select instanceof SetOperationList) { SetOperationList setOperationList = (SetOperationList) select; if (!CollectionUtils.isEmpty(setOperationList.getSelects())) { - setOperationList - .getSelects() - .forEach( - subSelectBody -> { - PlainSelect subPlainSelect = (PlainSelect) subSelectBody; - plainSelectList.add(subPlainSelect); - getFromSelect(subPlainSelect.getFromItem(), plainSelectList); - }); + setOperationList.getSelects().forEach(subSelectBody -> { + PlainSelect subPlainSelect = (PlainSelect) subSelectBody; + plainSelectList.add(subPlainSelect); + getFromSelect(subPlainSelect.getFromItem(), plainSelectList); + }); } } } @@ -161,8 +147,8 @@ public class SqlReplaceHelper { return replaceFields(sql, fieldNameMap, false); } - public static String replaceFields( - String sql, Map fieldNameMap, boolean exactReplace) { + public static String replaceFields(String sql, Map fieldNameMap, + boolean exactReplace) { Select selectStatement = SqlSelectHelper.getSelect(sql); List plainSelectList = SqlSelectHelper.getWithItem(selectStatement); if (selectStatement instanceof PlainSelect) { @@ -172,14 +158,11 @@ public class SqlReplaceHelper { } else if (selectStatement instanceof SetOperationList) { SetOperationList setOperationList = (SetOperationList) selectStatement; if (!CollectionUtils.isEmpty(setOperationList.getSelects())) { - setOperationList - .getSelects() - .forEach( - subSelectBody -> { - PlainSelect subPlainSelect = (PlainSelect) subSelectBody; - plainSelectList.add(subPlainSelect); - getFromSelect(subPlainSelect.getFromItem(), plainSelectList); - }); + setOperationList.getSelects().forEach(subSelectBody -> { + PlainSelect subPlainSelect = (PlainSelect) subSelectBody; + plainSelectList.add(subPlainSelect); + getFromSelect(subPlainSelect.getFromItem(), plainSelectList); + }); } List orderByElements = setOperationList.getOrderByElements(); if (!CollectionUtils.isEmpty(orderByElements)) { @@ -197,8 +180,8 @@ public class SqlReplaceHelper { return selectStatement.toString(); } - private static void replaceFieldsInPlainOneSelect( - Map fieldNameMap, boolean exactReplace, PlainSelect plainSelect) { + private static void replaceFieldsInPlainOneSelect(Map fieldNameMap, + boolean exactReplace, PlainSelect plainSelect) { // 1. replace where fields Expression where = plainSelect.getWhere(); FieldReplaceVisitor visitor = new FieldReplaceVisitor(fieldNameMap, exactReplace); @@ -220,14 +203,10 @@ public class SqlReplaceHelper { } else if (select instanceof SetOperationList) { SetOperationList setOperationList = (SetOperationList) select; if (!CollectionUtils.isEmpty(setOperationList.getSelects())) { - setOperationList - .getSelects() - .forEach( - subSelectBody -> { - PlainSelect subPlainSelect = (PlainSelect) subSelectBody; - replaceFieldsInPlainOneSelect( - fieldNameMap, exactReplace, subPlainSelect); - }); + setOperationList.getSelects().forEach(subSelectBody -> { + PlainSelect subPlainSelect = (PlainSelect) subSelectBody; + replaceFieldsInPlainOneSelect(fieldNameMap, exactReplace, subPlainSelect); + }); } } } @@ -253,11 +232,9 @@ public class SqlReplaceHelper { if (!CollectionUtils.isEmpty(joins)) { for (Join join : joins) { if (!CollectionUtils.isEmpty(join.getOnExpressions())) { - join.getOnExpressions().stream() - .forEach( - onExpression -> { - onExpression.accept(visitor); - }); + join.getOnExpressions().stream().forEach(onExpression -> { + onExpression.accept(visitor); + }); } if (!(join.getRightItem() instanceof ParenthesedSelect)) { continue; @@ -278,8 +255,8 @@ public class SqlReplaceHelper { return replaceFunction(sql, functionMap, null); } - public static String replaceFunction( - String sql, Map functionMap, Map functionCall) { + public static String replaceFunction(String sql, Map functionMap, + Map functionCall) { Select selectStatement = SqlSelectHelper.getSelect(sql); if (!(selectStatement instanceof PlainSelect)) { return sql; @@ -293,10 +270,8 @@ public class SqlReplaceHelper { return selectStatement.toString(); } - private static void replaceFunction( - Map functionMap, - Map functionCall, - PlainSelect selectBody) { + private static void replaceFunction(Map functionMap, + Map functionCall, PlainSelect selectBody) { PlainSelect plainSelect = selectBody; // 1. replace where dataDiff function Expression where = plainSelect.getWhere(); @@ -356,8 +331,8 @@ public class SqlReplaceHelper { } } - private static void replaceComparisonOperatorFunction( - Map functionMap, Expression expression) { + private static void replaceComparisonOperatorFunction(Map functionMap, + Expression expression) { if (Objects.isNull(expression)) { return; } @@ -376,8 +351,8 @@ public class SqlReplaceHelper { } } - private static void replaceOrderByFunction( - Map functionMap, List orderByElementList) { + private static void replaceOrderByFunction(Map functionMap, + List orderByElementList) { if (Objects.isNull(orderByElementList)) { return; } @@ -410,25 +385,23 @@ public class SqlReplaceHelper { List plainSelectList = SqlSelectHelper.getWithItem(selectStatement); if (!CollectionUtils.isEmpty(plainSelectList)) { List withNameList = SqlSelectHelper.getWithName(sql); - plainSelectList.stream() - .forEach( - plainSelect -> { - if (plainSelect.getFromItem() instanceof Table) { - Table table = (Table) plainSelect.getFromItem(); - if (!withNameList.contains(table.getName())) { - replaceSingleTable(plainSelect, tableName); - } - } - if (plainSelect.getFromItem() instanceof ParenthesedSelect) { - ParenthesedSelect parenthesedSelect = - (ParenthesedSelect) plainSelect.getFromItem(); - PlainSelect subPlainSelect = parenthesedSelect.getPlainSelect(); - Table table = (Table) subPlainSelect.getFromItem(); - if (!withNameList.contains(table.getName())) { - replaceSingleTable(subPlainSelect, tableName); - } - } - }); + plainSelectList.stream().forEach(plainSelect -> { + if (plainSelect.getFromItem() instanceof Table) { + Table table = (Table) plainSelect.getFromItem(); + if (!withNameList.contains(table.getName())) { + replaceSingleTable(plainSelect, tableName); + } + } + if (plainSelect.getFromItem() instanceof ParenthesedSelect) { + ParenthesedSelect parenthesedSelect = + (ParenthesedSelect) plainSelect.getFromItem(); + PlainSelect subPlainSelect = parenthesedSelect.getPlainSelect(); + Table table = (Table) subPlainSelect.getFromItem(); + if (!withNameList.contains(table.getName())) { + replaceSingleTable(subPlainSelect, tableName); + } + } + }); return selectStatement.toString(); } if (selectStatement instanceof PlainSelect) { @@ -438,14 +411,11 @@ public class SqlReplaceHelper { } else if (selectStatement instanceof SetOperationList) { SetOperationList setOperationList = (SetOperationList) selectStatement; if (!CollectionUtils.isEmpty(setOperationList.getSelects())) { - setOperationList - .getSelects() - .forEach( - subSelectBody -> { - PlainSelect subPlainSelect = (PlainSelect) subSelectBody; - replaceSingleTable(subPlainSelect, tableName); - replaceSubTable(subPlainSelect, tableName); - }); + setOperationList.getSelects().forEach(subSelectBody -> { + PlainSelect subPlainSelect = (PlainSelect) subSelectBody; + replaceSingleTable(subPlainSelect, tableName); + replaceSubTable(subPlainSelect, tableName); + }); } } @@ -476,15 +446,12 @@ public class SqlReplaceHelper { plainSelects.add(plainSelect); List painSelects = SqlSelectHelper.getPlainSelects(plainSelects); for (PlainSelect painSelect : painSelects) { - painSelect.accept( - new SelectVisitorAdapter() { - @Override - public void visit(PlainSelect plainSelect) { - plainSelect - .getFromItem() - .accept(new TableNameReplaceVisitor(tableName)); - } - }); + painSelect.accept(new SelectVisitorAdapter() { + @Override + public void visit(PlainSelect plainSelect) { + plainSelect.getFromItem().accept(new TableNameReplaceVisitor(tableName)); + } + }); List joins = painSelect.getJoins(); if (!CollectionUtils.isEmpty(joins)) { for (Join join : joins) { @@ -494,8 +461,7 @@ public class SqlReplaceHelper { List subPlainSelects = SqlSelectHelper.getPlainSelects(plainSelectList); for (PlainSelect subPlainSelect : subPlainSelects) { - subPlainSelect - .getFromItem() + subPlainSelect.getFromItem() .accept(new TableNameReplaceVisitor(tableName)); } } else if (join.getRightItem() instanceof Table) { @@ -524,8 +490,8 @@ public class SqlReplaceHelper { return selectStatement.toString(); } - public static String replaceHavingValue( - String sql, Map> filedNameToValueMap) { + public static String replaceHavingValue(String sql, + Map> filedNameToValueMap) { Select selectStatement = SqlSelectHelper.getSelect(sql); if (!(selectStatement instanceof PlainSelect)) { return sql; @@ -539,8 +505,8 @@ public class SqlReplaceHelper { return selectStatement.toString(); } - public static Expression distinguishDateDiffFilter( - Expression leftExpression, Expression expression) { + public static Expression distinguishDateDiffFilter(Expression leftExpression, + Expression expression) { if (leftExpression instanceof Function) { Function function = (Function) leftExpression; if (function.getName().equals(JsqlConstants.DATE_FUNCTION)) { @@ -558,17 +524,14 @@ public class SqlReplaceHelper { String endDateCondExpr = columnName + endDateOperator + StringUtil.getCommaWrap(endDateValue); - ComparisonOperator rightExpression = - (ComparisonOperator) - CCJSqlParserUtil.parseCondExpression(endDateCondExpr); + ComparisonOperator rightExpression = (ComparisonOperator) CCJSqlParserUtil + .parseCondExpression(endDateCondExpr); String startDateCondExpr = - columnName - + StringUtil.getSpaceWrap(startDateOperator) + columnName + StringUtil.getSpaceWrap(startDateOperator) + StringUtil.getCommaWrap(startDateValue); - ComparisonOperator newLeftExpression = - (ComparisonOperator) - CCJSqlParserUtil.parseCondExpression(startDateCondExpr); + ComparisonOperator newLeftExpression = (ComparisonOperator) CCJSqlParserUtil + .parseCondExpression(startDateCondExpr); AndExpression andExpression = new AndExpression(newLeftExpression, rightExpression); @@ -576,8 +539,8 @@ public class SqlReplaceHelper { || JsqlConstants.GREATER_THAN_EQUALS.equals(dateOperator)) { return newLeftExpression; } else { - return CCJSqlParserUtil.parseCondExpression( - "(" + andExpression.toString() + ")"); + return CCJSqlParserUtil + .parseCondExpression("(" + andExpression.toString() + ")"); } } catch (JSQLParserException e) { log.error("JSQLParserException", e); @@ -608,30 +571,24 @@ public class SqlReplaceHelper { } } } - plainSelect.getOrderByElements().stream() - .forEach( - o -> { - if (o.getExpression() instanceof Function) { - Function function = (Function) o.getExpression(); - if (function.getParameters().size() == 1 - && function.getParameters().get(0) - instanceof Column) { - Column column = - (Column) function.getParameters().get(0); - if (selectNames.containsKey(column.getColumnName())) { - o.setExpression( - new LongValue( - selectNames.get( - column.getColumnName()))); - } - } - } - }); + plainSelect.getOrderByElements().stream().forEach(o -> { + if (o.getExpression() instanceof Function) { + Function function = (Function) o.getExpression(); + if (function.getParameters().size() == 1 + && function.getParameters().get(0) instanceof Column) { + Column column = (Column) function.getParameters().get(0); + if (selectNames.containsKey(column.getColumnName())) { + o.setExpression( + new LongValue(selectNames.get(column.getColumnName()))); + } + } + } + }); } if (plainSelect.getFromItem() instanceof ParenthesedSelect) { ParenthesedSelect parenthesedSelect = (ParenthesedSelect) plainSelect.getFromItem(); - parenthesedSelect.setSelect( - replaceAggAliasOrderItem(parenthesedSelect.getSelect())); + parenthesedSelect + .setSelect(replaceAggAliasOrderItem(parenthesedSelect.getSelect())); } return selectStatement; } @@ -665,13 +622,10 @@ public class SqlReplaceHelper { } else if (selectStatement instanceof SetOperationList) { SetOperationList setOperationList = (SetOperationList) selectStatement; if (!CollectionUtils.isEmpty(setOperationList.getSelects())) { - setOperationList - .getSelects() - .forEach( - subSelectBody -> { - PlainSelect subPlainSelect = (PlainSelect) subSelectBody; - plainSelectList.add(subPlainSelect); - }); + setOperationList.getSelects().forEach(subSelectBody -> { + PlainSelect subPlainSelect = (PlainSelect) subSelectBody; + plainSelectList.add(subPlainSelect); + }); } } else { return sql; @@ -683,8 +637,8 @@ public class SqlReplaceHelper { return selectStatement.toString(); } - private static void replacePlainSelectByExpr( - PlainSelect plainSelect, Map replace) { + private static void replacePlainSelectByExpr(PlainSelect plainSelect, + Map replace) { QueryExpressionReplaceVisitor expressionReplaceVisitor = new QueryExpressionReplaceVisitor(replace); for (SelectItem selectItem : plainSelect.getSelectItems()) { @@ -703,9 +657,8 @@ public class SqlReplaceHelper { List orderByElements = plainSelect.getOrderByElements(); if (!CollectionUtils.isEmpty(orderByElements)) { for (OrderByElement orderByElement : orderByElements) { - orderByElement.setExpression( - QueryExpressionReplaceVisitor.replace( - orderByElement.getExpression(), replace)); + orderByElement.setExpression(QueryExpressionReplaceVisitor + .replace(orderByElement.getExpression(), replace)); } } } diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectFunctionHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectFunctionHelper.java index 019063ab5..e9e7a4992 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectFunctionHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectFunctionHelper.java @@ -58,8 +58,8 @@ public class SqlSelectFunctionHelper { return visitor.getFunctionNames(); } - public static Function getFunction( - Expression expression, Map fieldNameToAggregate) { + public static Function getFunction(Expression expression, + Map fieldNameToAggregate) { if (!(expression instanceof Column)) { return null; } @@ -100,8 +100,7 @@ public class SqlSelectFunctionHelper { FunctionVisitor visitor = new FunctionVisitor(); expression.accept(visitor); Set functions = visitor.getFunctionNames(); - return functions.stream() - .filter(t -> aggregateFunctionName.contains(t.toUpperCase())) + return functions.stream().filter(t -> aggregateFunctionName.contains(t.toUpperCase())) .collect(Collectors.toList()); } return new ArrayList<>(); diff --git a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java index cc5305ea7..5ea832fbb 100644 --- a/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java +++ b/common/src/main/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelper.java @@ -70,12 +70,9 @@ public class SqlSelectHelper { having.accept(new FieldAndValueAcquireVisitor(result)); } } - result = - result.stream() - .filter( - fieldExpression -> - StringUtils.isNotBlank(fieldExpression.getFieldName())) - .collect(Collectors.toSet()); + result = result.stream() + .filter(fieldExpression -> StringUtils.isNotBlank(fieldExpression.getFieldName())) + .collect(Collectors.toSet()); return new ArrayList<>(result); } @@ -90,31 +87,27 @@ public class SqlSelectHelper { } public static void getWhereFields(List plainSelectList, Set result) { - plainSelectList.stream() - .forEach( - plainSelect -> { - Expression where = plainSelect.getWhere(); - if (Objects.nonNull(where)) { - where.accept(new FieldAcquireVisitor(result)); - } - }); + plainSelectList.stream().forEach(plainSelect -> { + Expression where = plainSelect.getWhere(); + if (Objects.nonNull(where)) { + where.accept(new FieldAcquireVisitor(result)); + } + }); } public static List gePureSelectFields(String sql) { List plainSelectList = getPlainSelect(sql); Set result = new HashSet<>(); - plainSelectList.stream() - .forEach( - plainSelect -> { - List> selectItems = plainSelect.getSelectItems(); - for (SelectItem selectItem : selectItems) { - if (!(selectItem.getExpression() instanceof Column)) { - continue; - } - Column column = (Column) selectItem.getExpression(); - result.add(column.getColumnName()); - } - }); + plainSelectList.stream().forEach(plainSelect -> { + List> selectItems = plainSelect.getSelectItems(); + for (SelectItem selectItem : selectItems) { + if (!(selectItem.getExpression() instanceof Column)) { + continue; + } + Column column = (Column) selectItem.getExpression(); + result.add(column.getColumnName()); + } + }); return new ArrayList<>(result); } @@ -128,14 +121,12 @@ public class SqlSelectHelper { public static Set getSelectFields(List plainSelectList) { Set result = new HashSet<>(); - plainSelectList.stream() - .forEach( - plainSelect -> { - List> selectItems = plainSelect.getSelectItems(); - for (SelectItem selectItem : selectItems) { - selectItem.accept(new FieldAcquireVisitor(result)); - } - }); + plainSelectList.stream().forEach(plainSelect -> { + List> selectItems = plainSelect.getSelectItems(); + for (SelectItem selectItem : selectItems) { + selectItem.accept(new FieldAcquireVisitor(result)); + } + }); return result; } @@ -152,13 +143,10 @@ public class SqlSelectHelper { } else if (selectStatement instanceof SetOperationList) { SetOperationList setOperationList = (SetOperationList) selectStatement; if (!CollectionUtils.isEmpty(setOperationList.getSelects())) { - setOperationList - .getSelects() - .forEach( - subSelectBody -> { - PlainSelect subPlainSelect = (PlainSelect) subSelectBody; - getSubPlainSelect(subPlainSelect, plainSelectList); - }); + setOperationList.getSelects().forEach(subSelectBody -> { + PlainSelect subPlainSelect = (PlainSelect) subSelectBody; + getSubPlainSelect(subPlainSelect, plainSelectList); + }); } } return plainSelectList; @@ -235,39 +223,37 @@ public class SqlSelectHelper { List plainSelects = new ArrayList<>(); for (PlainSelect plainSelect : plainSelectList) { plainSelects.add(plainSelect); - ExpressionVisitorAdapter expressionVisitor = - new ExpressionVisitorAdapter() { - @Override - public void visit(Select subSelect) { - if (subSelect instanceof ParenthesedSelect) { - ParenthesedSelect parenthesedSelect = (ParenthesedSelect) subSelect; - if (parenthesedSelect.getSelect() instanceof PlainSelect) { - plainSelects.add(parenthesedSelect.getPlainSelect()); - } - } + ExpressionVisitorAdapter expressionVisitor = new ExpressionVisitorAdapter() { + @Override + public void visit(Select subSelect) { + if (subSelect instanceof ParenthesedSelect) { + ParenthesedSelect parenthesedSelect = (ParenthesedSelect) subSelect; + if (parenthesedSelect.getSelect() instanceof PlainSelect) { + plainSelects.add(parenthesedSelect.getPlainSelect()); } - }; + } + } + }; - plainSelect.accept( - new SelectVisitorAdapter() { - @Override - public void visit(PlainSelect plainSelect) { - Expression whereExpression = plainSelect.getWhere(); - if (whereExpression != null) { - whereExpression.accept(expressionVisitor); - } - Expression having = plainSelect.getHaving(); - if (Objects.nonNull(having)) { - having.accept(expressionVisitor); - } - List> selectItems = plainSelect.getSelectItems(); - if (!CollectionUtils.isEmpty(selectItems)) { - for (SelectItem selectItem : selectItems) { - selectItem.accept(expressionVisitor); - } - } + plainSelect.accept(new SelectVisitorAdapter() { + @Override + public void visit(PlainSelect plainSelect) { + Expression whereExpression = plainSelect.getWhere(); + if (whereExpression != null) { + whereExpression.accept(expressionVisitor); + } + Expression having = plainSelect.getHaving(); + if (Objects.nonNull(having)) { + having.accept(expressionVisitor); + } + List> selectItems = plainSelect.getSelectItems(); + if (!CollectionUtils.isEmpty(selectItems)) { + for (SelectItem selectItem : selectItems) { + selectItem.accept(expressionVisitor); } - }); + } + } + }); } return plainSelects; } @@ -313,14 +299,11 @@ public class SqlSelectHelper { private static void getLateralViewsFields(PlainSelect plainSelect, Set result) { List lateralViews = plainSelect.getLateralViews(); if (!CollectionUtils.isEmpty(lateralViews)) { - lateralViews.stream() - .forEach( - l -> { - if (Objects.nonNull(l.getGeneratorFunction())) { - l.getGeneratorFunction() - .accept(new FieldAcquireVisitor(result)); - } - }); + lateralViews.stream().forEach(l -> { + if (Objects.nonNull(l.getGeneratorFunction())) { + l.getGeneratorFunction().accept(new FieldAcquireVisitor(result)); + } + }); } } @@ -425,11 +408,9 @@ public class SqlSelectHelper { private static void getOrderByFields(PlainSelect plainSelect, Set result) { Set orderByFieldExpressions = getOrderByFields(plainSelect); - Set collect = - orderByFieldExpressions.stream() - .map(fieldExpression -> fieldExpression.getFieldName()) - .filter(Objects::nonNull) - .collect(Collectors.toSet()); + Set collect = orderByFieldExpressions.stream() + .map(fieldExpression -> fieldExpression.getFieldName()).filter(Objects::nonNull) + .collect(Collectors.toSet()); result.addAll(collect); } @@ -487,9 +468,8 @@ public class SqlSelectHelper { if (selectItem.getExpression() instanceof Function) { Function function = (Function) selectItem.getExpression(); - if (Objects.nonNull(function.getParameters()) - && !CollectionUtils.isEmpty( - function.getParameters().getExpressions())) { + if (Objects.nonNull(function.getParameters()) && !CollectionUtils + .isEmpty(function.getParameters().getExpressions())) { String columnName = function.getParameters().getExpressions().get(0).toString(); result.add(columnName); @@ -516,9 +496,8 @@ public class SqlSelectHelper { if (alias != null && StringUtils.isNotBlank(alias.getName())) { result.add(alias.getName()); } else { - if (Objects.nonNull(function.getParameters()) - && !CollectionUtils.isEmpty( - function.getParameters().getExpressions())) { + if (Objects.nonNull(function.getParameters()) && !CollectionUtils + .isEmpty(function.getParameters().getExpressions())) { String columnName = function.getParameters().getExpressions().get(0).toString(); result.add(columnName); @@ -552,9 +531,8 @@ public class SqlSelectHelper { } public static boolean isLogicExpression(Expression whereExpression) { - return whereExpression instanceof AndExpression - || (whereExpression instanceof OrExpression - || (whereExpression instanceof XorExpression)); + return whereExpression instanceof AndExpression || (whereExpression instanceof OrExpression + || (whereExpression instanceof XorExpression)); } public static String getColumnName(Expression leftExpression, Expression rightExpression) { @@ -789,8 +767,8 @@ public class SqlSelectHelper { return results; } - private static void getFieldsWithSubQuery( - PlainSelect plainSelect, Map> fields) { + private static void getFieldsWithSubQuery(PlainSelect plainSelect, + Map> fields) { if (plainSelect.getFromItem() instanceof Table) { List withAlias = new ArrayList<>(); if (!CollectionUtils.isEmpty(plainSelect.getWithItemsList())) { @@ -807,10 +785,8 @@ public class SqlSelectHelper { if (!fields.containsKey(table.getFullyQualifiedName())) { fields.put(tableName, new HashSet<>()); } - List sqlFields = - getFieldsByPlainSelect(plainSelect).stream() - .map(f -> f.replaceAll("`", "")) - .collect(Collectors.toList()); + List sqlFields = getFieldsByPlainSelect(plainSelect).stream() + .map(f -> f.replaceAll("`", "")).collect(Collectors.toList()); fields.get(tableName).addAll(sqlFields); } } @@ -826,8 +802,8 @@ public class SqlSelectHelper { ((ParenthesedSelect) join.getRightItem()).getPlainSelect(), fields); } if (join.getFromItem() instanceof ParenthesedSelect) { - getFieldsWithSubQuery( - ((ParenthesedSelect) join.getFromItem()).getPlainSelect(), fields); + getFieldsWithSubQuery(((ParenthesedSelect) join.getFromItem()).getPlainSelect(), + fields); } } } diff --git a/common/src/main/java/com/tencent/supersonic/common/persistence/mapper/SystemConfigMapper.java b/common/src/main/java/com/tencent/supersonic/common/persistence/mapper/SystemConfigMapper.java index bcc0e07db..3abff6924 100644 --- a/common/src/main/java/com/tencent/supersonic/common/persistence/mapper/SystemConfigMapper.java +++ b/common/src/main/java/com/tencent/supersonic/common/persistence/mapper/SystemConfigMapper.java @@ -5,4 +5,5 @@ import com.tencent.supersonic.common.persistence.dataobject.SystemConfigDO; import org.apache.ibatis.annotations.Mapper; @Mapper -public interface SystemConfigMapper extends BaseMapper {} +public interface SystemConfigMapper extends BaseMapper { +} diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/Criterion.java b/common/src/main/java/com/tencent/supersonic/common/pojo/Criterion.java index 4240b4079..1b1dba5f6 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/Criterion.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/Criterion.java @@ -35,23 +35,15 @@ public class Criterion { public boolean isNeedApostrophe() { return Arrays.stream(StringDataType.values()) - .filter(value -> this.dataType.equalsIgnoreCase(value.getType())) - .findFirst() + .filter(value -> this.dataType.equalsIgnoreCase(value.getType())).findFirst() .isPresent(); } public enum NumericDataType { - TINYINT("TINYINT"), - SMALLINT("SMALLINT"), - MEDIUMINT("MEDIUMINT"), - INT("INT"), - INTEGER("INTEGER"), - BIGINT("BIGINT"), - FLOAT("FLOAT"), - DOUBLE("DOUBLE"), - DECIMAL("DECIMAL"), - NUMERIC("NUMERIC"), - ; + TINYINT("TINYINT"), SMALLINT("SMALLINT"), MEDIUMINT("MEDIUMINT"), INT("INT"), INTEGER( + "INTEGER"), BIGINT("BIGINT"), FLOAT( + "FLOAT"), DOUBLE("DOUBLE"), DECIMAL("DECIMAL"), NUMERIC("NUMERIC"),; + private String type; NumericDataType(String type) { @@ -64,9 +56,8 @@ public class Criterion { } public enum StringDataType { - VARCHAR("VARCHAR"), - STRING("STRING"), - ; + VARCHAR("VARCHAR"), STRING("STRING"),; + private String type; StringDataType(String type) { diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/DataUpdateEvent.java b/common/src/main/java/com/tencent/supersonic/common/pojo/DataUpdateEvent.java index 0fe9c28de..2e433f9d3 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/DataUpdateEvent.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/DataUpdateEvent.java @@ -11,8 +11,8 @@ public class DataUpdateEvent extends ApplicationEvent { private Long id; private TypeEnums type; - public DataUpdateEvent( - Object source, String name, String newName, Long modelId, Long id, TypeEnums type) { + public DataUpdateEvent(Object source, String name, String newName, Long modelId, Long id, + TypeEnums type) { super(source); this.name = name; this.newName = newName; diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/DateConf.java b/common/src/main/java/com/tencent/supersonic/common/pojo/DateConf.java index 784753cf6..f162395bf 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/DateConf.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/DateConf.java @@ -70,10 +70,8 @@ public class DateConf { return false; } DateConf dateConf = (DateConf) o; - return dateMode == dateConf.dateMode - && Objects.equals(startDate, dateConf.startDate) - && Objects.equals(endDate, dateConf.endDate) - && Objects.equals(unit, dateConf.unit) + return dateMode == dateConf.dateMode && Objects.equals(startDate, dateConf.startDate) + && Objects.equals(endDate, dateConf.endDate) && Objects.equals(unit, dateConf.unit) && Objects.equals(period, dateConf.period); } @@ -89,11 +87,7 @@ public class DateConf { * the element, [unit, period] 4 - AVAILABLE, dynamic time which guaranteed to query some * data, [startDate, endDate] 5 - ALL, all table data */ - BETWEEN, - LIST, - RECENT, - AVAILABLE, - ALL + BETWEEN, LIST, RECENT, AVAILABLE, ALL } @Override diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/Filter.java b/common/src/main/java/com/tencent/supersonic/common/pojo/Filter.java index 6f79368f1..9ecb491ec 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/Filter.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/Filter.java @@ -47,8 +47,6 @@ public class Filter { } public enum Relation { - FILTER, - OR, - AND + FILTER, OR, AND } } diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/Parameter.java b/common/src/main/java/com/tencent/supersonic/common/pojo/Parameter.java index 67c4e6323..6861af708 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/Parameter.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/Parameter.java @@ -13,22 +13,28 @@ import java.util.Map; /** * 1.Password Field: * - *

dataType: string name: password require: true/false or any value/empty placeholder: 'Please - * enter the relevant configuration information' value: initial value Text Input Field: + *

+ * dataType: string name: password require: true/false or any value/empty placeholder: 'Please enter + * the relevant configuration information' value: initial value Text Input Field: * - *

2.dataType: string require: true/false or any value/empty placeholder: 'Please enter the - * relevant configuration information' value: initial value Long Text Input Field: + *

+ * 2.dataType: string require: true/false or any value/empty placeholder: 'Please enter the relevant + * configuration information' value: initial value Long Text Input Field: * - *

3.dataType: longText require: true/false or any value/empty placeholder: 'Please enter the + *

+ * 3.dataType: longText require: true/false or any value/empty placeholder: 'Please enter the * relevant configuration information' value: initial value Number Input Field: * - *

4.dataType: number require: true/false or any value/empty placeholder: 'Please enter the - * relevant configuration information' value: initial value Switch Component: + *

+ * 4.dataType: number require: true/false or any value/empty placeholder: 'Please enter the relevant + * configuration information' value: initial value Switch Component: * - *

5.dataType: bool require: true/false or any value/empty value: initial value Select Dropdown + *

+ * 5.dataType: bool require: true/false or any value/empty value: initial value Select Dropdown * Component: * - *

6.dataType: list candidateValues: ["OPEN_AI", "OLLAMA"] or [{label: 'Model Name 1', value: + *

+ * 6.dataType: list candidateValues: ["OPEN_AI", "OLLAMA"] or [{label: 'Model Name 1', value: * 'OPEN_AI'}, {label: 'Model Name 2', value: 'OLLAMA'}] require: true/false or any value/empty * placeholder: 'Please enter the relevant configuration information' value: initial value */ @@ -43,35 +49,18 @@ public class Parameter { private List candidateValues; private List dependencies; - public Parameter( - String name, - String defaultValue, - String comment, - String description, - String dataType, - String module) { + public Parameter(String name, String defaultValue, String comment, String description, + String dataType, String module) { this(name, defaultValue, comment, description, dataType, module, null, null); } - public Parameter( - String name, - String defaultValue, - String comment, - String description, - String dataType, - String module, - List candidateValues) { + public Parameter(String name, String defaultValue, String comment, String description, + String dataType, String module, List candidateValues) { this(name, defaultValue, comment, description, dataType, module, candidateValues, null); } - public Parameter( - String name, - String defaultValue, - String comment, - String description, - String dataType, - String module, - List candidateValues, + public Parameter(String name, String defaultValue, String comment, String description, + String dataType, String module, List candidateValues, List dependencies) { this.name = name; this.defaultValue = defaultValue; diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AggOperatorEnum.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AggOperatorEnum.java index 91e23b6a8..5271c34c8 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AggOperatorEnum.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AggOperatorEnum.java @@ -9,16 +9,13 @@ public enum AggOperatorEnum { SUM("SUM"), - COUNT("COUNT"), - COUNT_DISTINCT("COUNT_DISTINCT"), - DISTINCT("DISTINCT"), + COUNT("COUNT"), COUNT_DISTINCT("COUNT_DISTINCT"), DISTINCT("DISTINCT"), TOPN("TOPN"), PERCENTILE("PERCENTILE"), - RATIO_ROLL("RATIO_ROLL"), - RATIO_OVER("RATIO_OVER"), + RATIO_ROLL("RATIO_ROLL"), RATIO_OVER("RATIO_OVER"), UNKNOWN("UNKNOWN"); diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AggregateTypeEnum.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AggregateTypeEnum.java index 137c67124..ad50f46a2 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AggregateTypeEnum.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AggregateTypeEnum.java @@ -1,14 +1,7 @@ package com.tencent.supersonic.common.pojo.enums; public enum AggregateTypeEnum { - SUM, - AVG, - MAX, - MIN, - TOPN, - DISTINCT, - COUNT, - NONE; + SUM, AVG, MAX, MIN, TOPN, DISTINCT, COUNT, NONE; public static AggregateTypeEnum of(String agg) { for (AggregateTypeEnum aggEnum : AggregateTypeEnum.values()) { diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/ApiItemType.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/ApiItemType.java index c448fd518..4942f195b 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/ApiItemType.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/ApiItemType.java @@ -1,7 +1,5 @@ package com.tencent.supersonic.common.pojo.enums; public enum ApiItemType { - METRIC, - TAG, - DIMENSION + METRIC, TAG, DIMENSION } diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AuthType.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AuthType.java index a049a4c77..0c505a120 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AuthType.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AuthType.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.common.pojo.enums; public enum AuthType { - VISIBLE, - ADMIN + VISIBLE, ADMIN } diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/ConfigMode.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/ConfigMode.java index ca2a4f86a..501334d29 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/ConfigMode.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/ConfigMode.java @@ -1,9 +1,7 @@ package com.tencent.supersonic.common.pojo.enums; public enum ConfigMode { - DETAIL("DETAIL"), - AGG("AGG"), - UNKNOWN("UNKNOWN"); + DETAIL("DETAIL"), AGG("AGG"), UNKNOWN("UNKNOWN"); private String mode; diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/DatePeriodEnum.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/DatePeriodEnum.java index 4f2197e73..1e2e33c47 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/DatePeriodEnum.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/DatePeriodEnum.java @@ -1,11 +1,8 @@ package com.tencent.supersonic.common.pojo.enums; public enum DatePeriodEnum { - DAY("天"), - WEEK("周"), - MONTH("月"), - QUARTER("季度"), - YEAR("年"); + DAY("天"), WEEK("周"), MONTH("月"), QUARTER("季度"), YEAR("年"); + private String chName; DatePeriodEnum(String chName) { diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/DictWordType.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/DictWordType.java index 28b5b50ea..5259131d4 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/DictWordType.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/DictWordType.java @@ -51,8 +51,7 @@ public enum DictWordType { return DATASET; } // dimension value - if (natures.length == 3 - && StringUtils.isNumeric(natures[1]) + if (natures.length == 3 && StringUtils.isNumeric(natures[1]) && StringUtils.isNumeric(natures[2])) { return VALUE; } diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/EngineType.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/EngineType.java index 4b05d2620..8d5d5f931 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/EngineType.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/EngineType.java @@ -1,14 +1,8 @@ package com.tencent.supersonic.common.pojo.enums; public enum EngineType { - TDW(0, "tdw"), - MYSQL(1, "mysql"), - DORIS(2, "doris"), - CLICKHOUSE(3, "clickhouse"), - KAFKA(4, "kafka"), - H2(5, "h2"), - POSTGRESQL(6, "postgresql"), - OTHER(7, "other"); + TDW(0, "tdw"), MYSQL(1, "mysql"), DORIS(2, "doris"), CLICKHOUSE(3, "clickhouse"), KAFKA(4, + "kafka"), H2(5, "h2"), POSTGRESQL(6, "postgresql"), OTHER(7, "other"); private Integer code; diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/EventType.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/EventType.java index fb4968c68..5d6448576 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/EventType.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/EventType.java @@ -1,7 +1,5 @@ package com.tencent.supersonic.common.pojo.enums; public enum EventType { - ADD, - UPDATE, - DELETE + ADD, UPDATE, DELETE } diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/FilterOperatorEnum.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/FilterOperatorEnum.java index d54e4996d..0c21a3761 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/FilterOperatorEnum.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/FilterOperatorEnum.java @@ -10,20 +10,10 @@ import net.sf.jsqlparser.expression.operators.relational.MinorThan; import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals; public enum FilterOperatorEnum { - IN("IN"), - NOT_IN("NOT_IN"), - EQUALS("="), - BETWEEN("BETWEEN"), - GREATER_THAN(">"), - GREATER_THAN_EQUALS(">="), - IS_NULL("IS_NULL"), - IS_NOT_NULL("IS_NOT_NULL"), - LIKE("LIKE"), - MINOR_THAN("<"), - MINOR_THAN_EQUALS("<="), - NOT_EQUALS("!="), - SQL_PART("SQL_PART"), - EXISTS("EXISTS"); + IN("IN"), NOT_IN("NOT_IN"), EQUALS("="), BETWEEN("BETWEEN"), GREATER_THAN( + ">"), GREATER_THAN_EQUALS(">="), IS_NULL("IS_NULL"), IS_NOT_NULL("IS_NOT_NULL"), LIKE( + "LIKE"), MINOR_THAN("<"), MINOR_THAN_EQUALS( + "<="), NOT_EQUALS("!="), SQL_PART("SQL_PART"), EXISTS("EXISTS"); private String value; @@ -48,8 +38,7 @@ public enum FilterOperatorEnum { } public static boolean isValueCompare(FilterOperatorEnum filterOperatorEnum) { - return EQUALS.equals(filterOperatorEnum) - || GREATER_THAN.equals(filterOperatorEnum) + return EQUALS.equals(filterOperatorEnum) || GREATER_THAN.equals(filterOperatorEnum) || GREATER_THAN_EQUALS.equals(filterOperatorEnum) || MINOR_THAN.equals(filterOperatorEnum) || MINOR_THAN_EQUALS.equals(filterOperatorEnum) diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/PublishEnum.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/PublishEnum.java index f5777c30f..a1cec8900 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/PublishEnum.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/PublishEnum.java @@ -1,8 +1,7 @@ package com.tencent.supersonic.common.pojo.enums; public enum PublishEnum { - UN_PUBLISHED(0), - PUBLISHED(1); + UN_PUBLISHED(0), PUBLISHED(1); private Integer code; diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/RatioOverType.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/RatioOverType.java index 5083aaebe..1df9bf2a4 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/RatioOverType.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/RatioOverType.java @@ -1,13 +1,8 @@ package com.tencent.supersonic.common.pojo.enums; public enum RatioOverType { - DAY_ON_DAY("日环比"), - WEEK_ON_DAY("周环比"), - WEEK_ON_WEEK("周环比"), - MONTH_ON_WEEK("月环比"), - MONTH_ON_MONTH("月环比"), - YEAR_ON_MONTH("年同比"), - YEAR_ON_YEAR("年环比"); + DAY_ON_DAY("日环比"), WEEK_ON_DAY("周环比"), WEEK_ON_WEEK("周环比"), MONTH_ON_WEEK( + "月环比"), MONTH_ON_MONTH("月环比"), YEAR_ON_MONTH("年同比"), YEAR_ON_YEAR("年环比"); private String showName; diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/ReturnCode.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/ReturnCode.java index ce56f6279..9e1e11db1 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/ReturnCode.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/ReturnCode.java @@ -1,11 +1,9 @@ package com.tencent.supersonic.common.pojo.enums; public enum ReturnCode { - SUCCESS(200, "success"), - INVALID_REQUEST(400, "invalid request"), - INVALID_PERMISSION(401, "invalid permission"), - ACCESS_ERROR(403, "access denied"), - SYSTEM_ERROR(500, "system error"); + SUCCESS(200, "success"), INVALID_REQUEST(400, "invalid request"), INVALID_PERMISSION(401, + "invalid permission"), ACCESS_ERROR(403, + "access denied"), SYSTEM_ERROR(500, "system error"); private final int code; private final String message; diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/SensitiveLevelEnum.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/SensitiveLevelEnum.java index 1ecb5bd05..610d559bc 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/SensitiveLevelEnum.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/SensitiveLevelEnum.java @@ -1,9 +1,7 @@ package com.tencent.supersonic.common.pojo.enums; public enum SensitiveLevelEnum { - LOW(0), - MID(1), - HIGH(2); + LOW(0), MID(1), HIGH(2); private Integer code; diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/StatusEnum.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/StatusEnum.java index 168e1eda9..e0b948d03 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/StatusEnum.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/StatusEnum.java @@ -1,12 +1,8 @@ package com.tencent.supersonic.common.pojo.enums; public enum StatusEnum { - INITIALIZED("INITIALIZED", 0), - ONLINE("ONLINE", 1), - OFFLINE("OFFLINE", 2), - DELETED("DELETED", 3), - UNAVAILABLE("UNAVAILABLE", 4), - UNKNOWN("UNKNOWN", -1); + INITIALIZED("INITIALIZED", 0), ONLINE("ONLINE", 1), OFFLINE("OFFLINE", 2), DELETED("DELETED", + 3), UNAVAILABLE("UNAVAILABLE", 4), UNKNOWN("UNKNOWN", -1); private String status; private Integer code; diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/Text2SQLType.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/Text2SQLType.java index 89c107577..cd965f292 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/Text2SQLType.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/Text2SQLType.java @@ -1,9 +1,7 @@ package com.tencent.supersonic.common.pojo.enums; public enum Text2SQLType { - ONLY_RULE, - ONLY_LLM, - RULE_AND_LLM; + ONLY_RULE, ONLY_LLM, RULE_AND_LLM; public boolean enableRule() { return this.equals(ONLY_RULE) || this.equals(RULE_AND_LLM); diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TimeDimensionEnum.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TimeDimensionEnum.java index 22502c7bd..b12e2c9c1 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TimeDimensionEnum.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TimeDimensionEnum.java @@ -31,33 +31,23 @@ public enum TimeDimensionEnum { } public static List getNameList() { - return Arrays.stream(TimeDimensionEnum.values()) - .map(TimeDimensionEnum::getName) + return Arrays.stream(TimeDimensionEnum.values()).map(TimeDimensionEnum::getName) .collect(Collectors.toList()); } public static List getChNameList() { - return Arrays.stream(TimeDimensionEnum.values()) - .map(TimeDimensionEnum::getChName) + return Arrays.stream(TimeDimensionEnum.values()).map(TimeDimensionEnum::getChName) .collect(Collectors.toList()); } public static Map getChNameToNameMap() { - return Arrays.stream(TimeDimensionEnum.values()) - .collect( - Collectors.toMap( - TimeDimensionEnum::getChName, - TimeDimensionEnum::getName, - (k1, k2) -> k1)); + return Arrays.stream(TimeDimensionEnum.values()).collect(Collectors + .toMap(TimeDimensionEnum::getChName, TimeDimensionEnum::getName, (k1, k2) -> k1)); } public static Map getNameToNameMap() { - return Arrays.stream(TimeDimensionEnum.values()) - .collect( - Collectors.toMap( - TimeDimensionEnum::getName, - TimeDimensionEnum::getName, - (k1, k2) -> k1)); + return Arrays.stream(TimeDimensionEnum.values()).collect(Collectors + .toMap(TimeDimensionEnum::getName, TimeDimensionEnum::getName, (k1, k2) -> k1)); } public String getName() { diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TypeEnums.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TypeEnums.java index f78b2dee8..4d8b12e1e 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TypeEnums.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/TypeEnums.java @@ -1,13 +1,5 @@ package com.tencent.supersonic.common.pojo.enums; public enum TypeEnums { - METRIC, - DIMENSION, - TAG_OBJECT, - TAG, - DOMAIN, - ENTITY, - DATASET, - MODEL, - UNKNOWN + METRIC, DIMENSION, TAG_OBJECT, TAG, DOMAIN, ENTITY, DATASET, MODEL, UNKNOWN } diff --git a/common/src/main/java/com/tencent/supersonic/common/rest/SystemConfigController.java b/common/src/main/java/com/tencent/supersonic/common/rest/SystemConfigController.java index 12e1c6dd6..a3c06ead4 100644 --- a/common/src/main/java/com/tencent/supersonic/common/rest/SystemConfigController.java +++ b/common/src/main/java/com/tencent/supersonic/common/rest/SystemConfigController.java @@ -13,7 +13,8 @@ import org.springframework.web.bind.annotation.RestController; @RequestMapping({"/api/semantic/parameter"}) public class SystemConfigController { - @Autowired private SystemConfigService sysConfigService; + @Autowired + private SystemConfigService sysConfigService; @PostMapping public Boolean save(@RequestBody SystemConfig systemConfig) { diff --git a/common/src/main/java/com/tencent/supersonic/common/service/EmbeddingService.java b/common/src/main/java/com/tencent/supersonic/common/service/EmbeddingService.java index b36f9f06f..c83685648 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/EmbeddingService.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/EmbeddingService.java @@ -16,8 +16,8 @@ public interface EmbeddingService { void deleteQuery(String collectionName, List queries); - List retrieveQuery( - String collectionName, RetrieveQuery retrieveQuery, int num); + List retrieveQuery(String collectionName, RetrieveQuery retrieveQuery, + int num); void removeAll(); } diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java index 5a3bd0322..c486a59c4 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/EmbeddingServiceImpl.java @@ -37,11 +37,8 @@ import java.util.stream.Collectors; @Slf4j public class EmbeddingServiceImpl implements EmbeddingService { - private Cache cache = - CacheBuilder.newBuilder() - .maximumSize(10000) - .expireAfterWrite(10, TimeUnit.HOURS) - .build(); + private Cache cache = CacheBuilder.newBuilder().maximumSize(10000) + .expireAfterWrite(10, TimeUnit.HOURS).build(); @Override public void addQuery(String collectionName, List queries) { @@ -59,17 +56,14 @@ public class EmbeddingServiceImpl implements EmbeddingService { embeddingStore.add(embedding, query); cache.put(TextSegmentConvert.getQueryId(query), true); } catch (Exception e) { - log.error( - "embeddingModel embed error question: {}, embeddingStore: {}", - question, - embeddingStore.getClass().getSimpleName(), - e); + log.error("embeddingModel embed error question: {}, embeddingStore: {}", question, + embeddingStore.getClass().getSimpleName(), e); } } } - private boolean existSegment( - EmbeddingStore embeddingStore, TextSegment query, Embedding embedding) { + private boolean existSegment(EmbeddingStore embeddingStore, TextSegment query, + Embedding embedding) { String queryId = TextSegmentConvert.getQueryId(query); if (queryId == null) { return false; @@ -82,13 +76,8 @@ public class EmbeddingServiceImpl implements EmbeddingService { Map filterCondition = new HashMap<>(); filterCondition.put(TextSegmentConvert.QUERY_ID, queryId); Filter filter = createCombinedFilter(filterCondition); - EmbeddingSearchRequest request = - EmbeddingSearchRequest.builder() - .queryEmbedding(embedding) - .filter(filter) - .minScore(1.0d) - .maxResults(1) - .build(); + EmbeddingSearchRequest request = EmbeddingSearchRequest.builder().queryEmbedding(embedding) + .filter(filter).minScore(1.0d).maxResults(1).build(); EmbeddingSearchResult result = embeddingStore.search(request); List> relevant = result.matches(); @@ -104,10 +93,8 @@ public class EmbeddingServiceImpl implements EmbeddingService { try { List queryIds = - queries.stream() - .map(textSegment -> TextSegmentConvert.getQueryId(textSegment)) - .filter(Objects::nonNull) - .collect(Collectors.toList()); + queries.stream().map(textSegment -> TextSegmentConvert.getQueryId(textSegment)) + .filter(Objects::nonNull).collect(Collectors.toList()); if (CollectionUtils.isNotEmpty(queryIds)) { MetadataFilterBuilder filterBuilder = new MetadataFilterBuilder(TextSegmentConvert.QUERY_ID); @@ -122,21 +109,15 @@ public class EmbeddingServiceImpl implements EmbeddingService { } @Override - public List retrieveQuery( - String collectionName, RetrieveQuery retrieveQuery, int num) { + public List retrieveQuery(String collectionName, + RetrieveQuery retrieveQuery, int num) { EmbeddingStore embeddingStore = EmbeddingStoreFactoryProvider.getFactory().create(collectionName); EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(); Map filterCondition = retrieveQuery.getFilterCondition(); - return retrieveQuery.getQueryTextsList().stream() - .map( - queryText -> - retrieveSingleQuery( - queryText, - embeddingModel, - embeddingStore, - filterCondition, - num)) + return retrieveQuery + .getQueryTextsList().stream().map(queryText -> retrieveSingleQuery(queryText, + embeddingModel, embeddingStore, filterCondition, num)) .collect(Collectors.toList()); } @@ -152,28 +133,17 @@ public class EmbeddingServiceImpl implements EmbeddingService { cache.invalidateAll(); } - private RetrieveQueryResult retrieveSingleQuery( - String queryText, - EmbeddingModel embeddingModel, - EmbeddingStore embeddingStore, - Map filterCondition, - int num) { + private RetrieveQueryResult retrieveSingleQuery(String queryText, EmbeddingModel embeddingModel, + EmbeddingStore embeddingStore, Map filterCondition, int num) { Embedding embeddedText = embeddingModel.embed(queryText).content(); Filter filter = createCombinedFilter(filterCondition); - EmbeddingSearchRequest request = - EmbeddingSearchRequest.builder() - .queryEmbedding(embeddedText) - .filter(filter) - .maxResults(num) - .build(); + EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() + .queryEmbedding(embeddedText).filter(filter).maxResults(num).build(); EmbeddingSearchResult result = embeddingStore.search(request); - List retrievals = - result.matches().stream() - .map(this::convertToRetrieval) - .sorted(Comparator.comparingDouble(Retrieval::getSimilarity)) - .limit(num) - .collect(Collectors.toList()); + List retrievals = result.matches().stream().map(this::convertToRetrieval) + .sorted(Comparator.comparingDouble(Retrieval::getSimilarity)).limit(num) + .collect(Collectors.toList()); RetrieveQueryResult retrieveQueryResult = new RetrieveQueryResult(); retrieveQueryResult.setQuery(queryText); @@ -209,10 +179,8 @@ public class EmbeddingServiceImpl implements EmbeddingService { // Create an OR filter for each value in the list for (String value : (List) fieldValue) { IsEqualTo equalToFilter = new IsEqualTo(fieldName, value); - fieldFilter = - (fieldFilter == null) - ? equalToFilter - : Filter.or(fieldFilter, equalToFilter); + fieldFilter = (fieldFilter == null) ? equalToFilter + : Filter.or(fieldFilter, equalToFilter); } } else if (fieldValue instanceof String) { // Create a simple equality filter @@ -220,10 +188,8 @@ public class EmbeddingServiceImpl implements EmbeddingService { } // Combine the current field filter with the overall filter using AND logic if (fieldFilter != null) { - combinedFilter = - (combinedFilter == null) - ? fieldFilter - : Filter.and(combinedFilter, fieldFilter); + combinedFilter = (combinedFilter == null) ? fieldFilter + : Filter.and(combinedFilter, fieldFilter); } } return combinedFilter; diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java index e880d7b12..7819fa8e1 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/ExemplarServiceImpl.java @@ -35,14 +35,15 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner { private final ObjectMapper objectMapper = JsonUtil.INSTANCE.getObjectMapper(); - @Autowired private EmbeddingConfig embeddingConfig; + @Autowired + private EmbeddingConfig embeddingConfig; - @Autowired private EmbeddingService embeddingService; + @Autowired + private EmbeddingService embeddingService; public void storeExemplar(String collection, Text2SQLExemplar exemplar) { - Metadata metadata = - Metadata.from( - JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class)); + Metadata metadata = Metadata + .from(JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class)); TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata); TextSegmentConvert.addQueryId(segment, exemplar.getQuestion()); @@ -50,9 +51,8 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner { } public void removeExemplar(String collection, Text2SQLExemplar exemplar) { - Metadata metadata = - Metadata.from( - JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class)); + Metadata metadata = Metadata + .from(JsonUtil.toMap(JsonUtil.toString(exemplar), String.class, Object.class)); TextSegment segment = TextSegment.from(exemplar.getQuestion(), metadata); TextSegmentConvert.addQueryId(segment, exemplar.getQuestion()); @@ -70,18 +70,11 @@ public class ExemplarServiceImpl implements ExemplarService, CommandLineRunner { RetrieveQuery.builder().queryTextsList(Lists.newArrayList(query)).build(); List results = embeddingService.retrieveQuery(collection, retrieveQuery, num); - results.stream() - .forEach( - ret -> { - ret.getRetrieval().stream() - .forEach( - r -> { - exemplars.add( - JsonUtil.mapToObject( - r.getMetadata(), - Text2SQLExemplar.class)); - }); - }); + results.stream().forEach(ret -> { + ret.getRetrieval().stream().forEach(r -> { + exemplars.add(JsonUtil.mapToObject(r.getMetadata(), Text2SQLExemplar.class)); + }); + }); return exemplars; } diff --git a/common/src/main/java/com/tencent/supersonic/common/service/impl/SystemConfigServiceImpl.java b/common/src/main/java/com/tencent/supersonic/common/service/impl/SystemConfigServiceImpl.java index 4b48331ad..69c1410fa 100644 --- a/common/src/main/java/com/tencent/supersonic/common/service/impl/SystemConfigServiceImpl.java +++ b/common/src/main/java/com/tencent/supersonic/common/service/impl/SystemConfigServiceImpl.java @@ -21,7 +21,8 @@ import java.util.concurrent.atomic.AtomicReference; public class SystemConfigServiceImpl extends ServiceImpl implements SystemConfigService { - @Autowired private Environment environment; + @Autowired + private Environment environment; // Cache field to store the system configuration private AtomicReference cachedSystemConfig = new AtomicReference<>(); @@ -44,13 +45,11 @@ public class SystemConfigServiceImpl extends ServiceImpl { - if (environment.containsProperty(p.getName())) { - p.setValue(environment.getProperty(p.getName())); - } - }); + systemConfig.getParameters().stream().forEach(p -> { + if (environment.containsProperty(p.getName())) { + p.setValue(environment.getProperty(p.getName())); + } + }); save(systemConfig); return systemConfig; } @@ -68,9 +67,8 @@ public class SystemConfigServiceImpl extends ServiceImpl parameters = - JsonUtil.toObject( - systemConfigDO.getParameters(), new TypeReference>() {}); + List parameters = JsonUtil.toObject(systemConfigDO.getParameters(), + new TypeReference>() {}); sysParameter.setParameters(parameters); sysParameter.setAdminList(systemConfigDO.getAdmin()); return sysParameter; diff --git a/common/src/main/java/com/tencent/supersonic/common/util/AESEncryptionUtil.java b/common/src/main/java/com/tencent/supersonic/common/util/AESEncryptionUtil.java index c7fe1b81b..6686e2447 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/AESEncryptionUtil.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/AESEncryptionUtil.java @@ -1,13 +1,12 @@ package com.tencent.supersonic.common.util; +import lombok.extern.slf4j.Slf4j; + import javax.crypto.Cipher; import javax.crypto.SecretKeyFactory; import javax.crypto.spec.IvParameterSpec; import javax.crypto.spec.PBEKeySpec; import javax.crypto.spec.SecretKeySpec; - -import lombok.extern.slf4j.Slf4j; - import java.security.MessageDigest; import java.security.spec.KeySpec; import java.util.Arrays; @@ -121,10 +120,8 @@ public class AESEncryptionUtil { int len = hexString.length(); byte[] byteArray = new byte[len / 2]; for (int i = 0; i < len; i += 2) { - byteArray[i / 2] = - (byte) - ((Character.digit(hexString.charAt(i), 16) << 4) - + Character.digit(hexString.charAt(i + 1), 16)); + byteArray[i / 2] = (byte) ((Character.digit(hexString.charAt(i), 16) << 4) + + Character.digit(hexString.charAt(i + 1), 16)); } return byteArray; } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/DateModeUtils.java b/common/src/main/java/com/tencent/supersonic/common/util/DateModeUtils.java index c2ebfd990..a5753c085 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/DateModeUtils.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/DateModeUtils.java @@ -62,12 +62,10 @@ public class DateModeUtils { * @return */ public String hasDataModeStr(ItemDateResp dateDate, DateConf dateInfo) { - if (Objects.isNull(dateDate) - || StringUtils.isEmpty(dateDate.getStartDate()) + if (Objects.isNull(dateDate) || StringUtils.isEmpty(dateDate.getStartDate()) || StringUtils.isEmpty(dateDate.getStartDate())) { - return String.format( - "(%s >= '%s' and %s <= '%s')", - sysDateCol, dateInfo.getStartDate(), sysDateCol, dateInfo.getEndDate()); + return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol, dateInfo.getStartDate(), + sysDateCol, dateInfo.getEndDate()); } else { log.info("dateDate:{}", dateDate); } @@ -81,31 +79,22 @@ public class DateModeUtils { if (endReq.isAfter(endData)) { if (DatePeriodEnum.DAY.equals(dateInfo.getPeriod())) { - Long unit = - getInterval( - dateInfo.getStartDate(), - dateInfo.getEndDate(), - dateFormatStr, - ChronoUnit.DAYS); + Long unit = getInterval(dateInfo.getStartDate(), dateInfo.getEndDate(), + dateFormatStr, ChronoUnit.DAYS); LocalDate dateMax = endData; LocalDate dateMin = dateMax.minusDays(unit - 1); - return String.format( - "(%s >= '%s' and %s <= '%s')", sysDateCol, dateMin, sysDateCol, dateMax); + return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol, dateMin, sysDateCol, + dateMax); } if (DatePeriodEnum.MONTH.equals(dateInfo.getPeriod())) { - Long unit = - getInterval( - dateInfo.getStartDate(), - dateInfo.getEndDate(), - dateFormatStr, - ChronoUnit.MONTHS); + Long unit = getInterval(dateInfo.getStartDate(), dateInfo.getEndDate(), + dateFormatStr, ChronoUnit.MONTHS); return generateMonthSql(endData, unit, dateFormatStr); } } - return String.format( - "(%s >= '%s' and %s <= '%s')", - sysDateCol, dateInfo.getStartDate(), sysDateCol, dateInfo.getEndDate()); + return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol, dateInfo.getStartDate(), + sysDateCol, dateInfo.getEndDate()); } public String generateMonthSql(LocalDate endData, Long unit, String dateFormatStr) { @@ -131,9 +120,8 @@ public class DateModeUtils { public String recentDayStr(ItemDateResp dateDate, DateConf dateInfo) { ImmutablePair dayRange = recentDay(dateDate, dateInfo); - return String.format( - "(%s >= '%s' and %s <= '%s')", - sysDateCol, dayRange.left, sysDateCol, dayRange.right); + return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol, dayRange.left, sysDateCol, + dayRange.right); } public ImmutablePair recentDay(ItemDateResp dateDate, DateConf dateInfo) { @@ -143,7 +131,7 @@ public class DateModeUtils { } DateTimeFormatter formatter = DateTimeFormatter.ofPattern(dateFormatStr); LocalDate end = LocalDate.parse(dateDate.getEndDate(), formatter); - // todo unavailableDateList logic + // todo unavailableDateList logic Integer unit = dateInfo.getUnit() - 1; String start = end.minusDays(unit).format(formatter); @@ -154,16 +142,15 @@ public class DateModeUtils { DateTimeFormatter formatter = DateTimeFormatter.ofPattern(dateFormatStr); String endStr = endData.format(formatter); String start = endData.minusMonths(unit).format(formatter); - return String.format( - "(%s >= '%s' and %s <= '%s')", sysDateMonthCol, start, sysDateMonthCol, endStr); + return String.format("(%s >= '%s' and %s <= '%s')", sysDateMonthCol, start, sysDateMonthCol, + endStr); } public String recentMonthStr(ItemDateResp dateDate, DateConf dateInfo) { List> range = recentMonth(dateDate, dateInfo); if (range.size() == 1) { - return String.format( - "(%s >= '%s' and %s <= '%s')", - sysDateMonthCol, range.get(0).left, sysDateMonthCol, range.get(0).right); + return String.format("(%s >= '%s' and %s <= '%s')", sysDateMonthCol, range.get(0).left, + sysDateMonthCol, range.get(0).right); } if (range.size() > 0) { StringJoiner joiner = new StringJoiner(","); @@ -173,21 +160,15 @@ public class DateModeUtils { return ""; } - public List> recentMonth( - ItemDateResp dateDate, DateConf dateInfo) { - LocalDate endData = - LocalDate.parse( - dateDate.getEndDate(), - DateTimeFormatter.ofPattern(dateDate.getDateFormat())); + public List> recentMonth(ItemDateResp dateDate, + DateConf dateInfo) { + LocalDate endData = LocalDate.parse(dateDate.getEndDate(), + DateTimeFormatter.ofPattern(dateDate.getDateFormat())); List> ret = new ArrayList<>(); if (dateDate.getDatePeriod() != null && DatePeriodEnum.MONTH.equals(dateDate.getDatePeriod())) { - Long unit = - getInterval( - dateInfo.getStartDate(), - dateInfo.getEndDate(), - dateDate.getDateFormat(), - ChronoUnit.MONTHS); + Long unit = getInterval(dateInfo.getStartDate(), dateInfo.getEndDate(), + dateDate.getDateFormat(), ChronoUnit.MONTHS); LocalDate dateMax = endData; List months = generateMonthStr(dateMax, unit, dateDate.getDateFormat()); if (!CollectionUtils.isEmpty(months)) { @@ -207,16 +188,14 @@ public class DateModeUtils { public String recentWeekStr(LocalDate endData, Long unit) { DateTimeFormatter formatter = DateTimeFormatter.ofPattern(DAY_FORMAT); String start = endData.minusDays(unit * 7).format(formatter); - return String.format( - "(%s >= '%s' and %s <= '%s')", - sysDateWeekCol, start, sysDateWeekCol, endData.format(formatter)); + return String.format("(%s >= '%s' and %s <= '%s')", sysDateWeekCol, start, sysDateWeekCol, + endData.format(formatter)); } public String recentWeekStr(ItemDateResp dateDate, DateConf dateInfo) { ImmutablePair dayRange = recentWeek(dateDate, dateInfo); - return String.format( - "(%s >= '%s' and %s <= '%s')", - sysDateWeekCol, dayRange.left, sysDateWeekCol, dayRange.right); + return String.format("(%s >= '%s' and %s <= '%s')", sysDateWeekCol, dayRange.left, + sysDateWeekCol, dayRange.right); } public ImmutablePair recentWeek(ItemDateResp dateDate, DateConf dateInfo) { @@ -231,8 +210,8 @@ public class DateModeUtils { return ImmutablePair.of(start, end.format(formatter)); } - private Long getInterval( - String startDate, String endDate, String dateFormat, ChronoUnit chronoUnit) { + private Long getInterval(String startDate, String endDate, String dateFormat, + ChronoUnit chronoUnit) { DateTimeFormatter formatter = DateTimeFormatter.ofPattern(dateFormat); try { LocalDate start = LocalDate.parse(startDate, formatter); @@ -270,34 +249,23 @@ public class DateModeUtils { if (DatePeriodEnum.MONTH.equals(dateInfo.getPeriod())) { // startDate YYYYMM if (!dateInfo.getStartDate().contains(Constants.MINUS)) { - return String.format( - "%s >= '%s' and %s <= '%s'", - sysDateMonthCol, - dateInfo.getStartDate(), - sysDateMonthCol, - dateInfo.getEndDate()); + return String.format("%s >= '%s' and %s <= '%s'", sysDateMonthCol, + dateInfo.getStartDate(), sysDateMonthCol, dateInfo.getEndDate()); } LocalDate endData = LocalDate.parse(dateInfo.getEndDate(), DateTimeFormatter.ofPattern(DAY_FORMAT)); - LocalDate startData = - LocalDate.parse( - dateInfo.getStartDate(), DateTimeFormatter.ofPattern(DAY_FORMAT)); + LocalDate startData = LocalDate.parse(dateInfo.getStartDate(), + DateTimeFormatter.ofPattern(DAY_FORMAT)); DateTimeFormatter formatter = DateTimeFormatter.ofPattern(MONTH_FORMAT); - return String.format( - "%s >= '%s' and %s <= '%s'", - sysDateMonthCol, - startData.format(formatter), - sysDateMonthCol, - endData.format(formatter)); + return String.format("%s >= '%s' and %s <= '%s'", sysDateMonthCol, + startData.format(formatter), sysDateMonthCol, endData.format(formatter)); } if (DatePeriodEnum.WEEK.equals(dateInfo.getPeriod())) { - return String.format( - "%s >= '%s' and %s <= '%s'", - sysDateWeekCol, dateInfo.getStartDate(), sysDateWeekCol, dateInfo.getEndDate()); + return String.format("%s >= '%s' and %s <= '%s'", sysDateWeekCol, + dateInfo.getStartDate(), sysDateWeekCol, dateInfo.getEndDate()); } - return String.format( - "%s >= '%s' and %s <= '%s'", - sysDateCol, dateInfo.getStartDate(), sysDateCol, dateInfo.getEndDate()); + return String.format("%s >= '%s' and %s <= '%s'", sysDateCol, dateInfo.getStartDate(), + sysDateCol, dateInfo.getEndDate()); } /** @@ -335,8 +303,8 @@ public class DateModeUtils { if (DatePeriodEnum.DAY.equals(dateInfo.getPeriod())) { LocalDate dateMax = LocalDate.now().minusDays(1); LocalDate dateMin = dateMax.minusDays(unit - 1); - return String.format( - "(%s >= '%s' and %s <= '%s')", sysDateCol, dateMin, sysDateCol, dateMax); + return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol, dateMin, sysDateCol, + dateMax); } if (DatePeriodEnum.WEEK.equals(dateInfo.getPeriod())) { @@ -352,9 +320,8 @@ public class DateModeUtils { return recentMonthStr(dateMax, unit.longValue() * 12, MONTH_FORMAT); } - return String.format( - "(%s >= '%s' and %s <= '%s')", - sysDateCol, LocalDate.now().minusDays(2), sysDateCol, LocalDate.now().minusDays(1)); + return String.format("(%s >= '%s' and %s <= '%s')", sysDateCol, + LocalDate.now().minusDays(2), sysDateCol, LocalDate.now().minusDays(1)); } public String getDateWhereStr(DateConf dateInfo) { diff --git a/common/src/main/java/com/tencent/supersonic/common/util/DateUtils.java b/common/src/main/java/com/tencent/supersonic/common/util/DateUtils.java index 10101e237..1b2007ed9 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/DateUtils.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/DateUtils.java @@ -90,8 +90,8 @@ public class DateUtils { return startDate.format(DEFAULT_DATE_FORMATTER2); } - public static String getBeforeDate( - String currentDate, int intervalDay, DatePeriodEnum datePeriodEnum) { + public static String getBeforeDate(String currentDate, int intervalDay, + DatePeriodEnum datePeriodEnum) { LocalDate specifiedDate = LocalDate.parse(currentDate, DEFAULT_DATE_FORMATTER2); LocalDate result = null; switch (datePeriodEnum) { @@ -101,9 +101,8 @@ public class DateUtils { case WEEK: result = specifiedDate.minusWeeks(intervalDay); if (intervalDay == 0) { - result = - result.with( - TemporalAdjusters.previousOrSame(java.time.DayOfWeek.MONDAY)); + result = result + .with(TemporalAdjusters.previousOrSame(java.time.DayOfWeek.MONDAY)); } break; case MONTH: @@ -115,14 +114,13 @@ public class DateUtils { case QUARTER: result = specifiedDate.minusMonths(intervalDay * 3L); if (intervalDay == 0) { - TemporalAdjuster firstDayOfQuarter = - temporal -> { - LocalDate tempDate = LocalDate.from(temporal); - int month = tempDate.get(ChronoField.MONTH_OF_YEAR); - int firstMonthOfQuarter = ((month - 1) / 3) * 3 + 1; - return tempDate.with(ChronoField.MONTH_OF_YEAR, firstMonthOfQuarter) - .with(TemporalAdjusters.firstDayOfMonth()); - }; + TemporalAdjuster firstDayOfQuarter = temporal -> { + LocalDate tempDate = LocalDate.from(temporal); + int month = tempDate.get(ChronoField.MONTH_OF_YEAR); + int firstMonthOfQuarter = ((month - 1) / 3) * 3 + 1; + return tempDate.with(ChronoField.MONTH_OF_YEAR, firstMonthOfQuarter) + .with(TemporalAdjusters.firstDayOfMonth()); + }; result = result.with(firstDayOfQuarter); } break; @@ -162,8 +160,8 @@ public class DateUtils { return !timeString.equals("00:00:00"); } - public static List getDateList( - String startDateStr, String endDateStr, DatePeriodEnum period) { + public static List getDateList(String startDateStr, String endDateStr, + DatePeriodEnum period) { try { LocalDate startDate = LocalDate.parse(startDateStr); LocalDate endDate = LocalDate.parse(endDateStr); diff --git a/common/src/main/java/com/tencent/supersonic/common/util/FileUtils.java b/common/src/main/java/com/tencent/supersonic/common/util/FileUtils.java index dc0ceea26..182abf564 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/FileUtils.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/FileUtils.java @@ -22,12 +22,8 @@ public class FileUtils { return -1; } File file = new File(path); - Optional lastModified = - Arrays.stream(file.listFiles()) - .filter(f -> f.isFile()) - .map(f -> f.lastModified()) - .sorted(Collections.reverseOrder()) - .findFirst(); + Optional lastModified = Arrays.stream(file.listFiles()).filter(f -> f.isFile()) + .map(f -> f.lastModified()).sorted(Collections.reverseOrder()).findFirst(); if (lastModified.isPresent()) { return lastModified.get(); @@ -42,8 +38,8 @@ public class FileUtils { return null; } - public static void scanDirectory( - File file, int maxLevel, Map> directories) { + public static void scanDirectory(File file, int maxLevel, + Map> directories) { if (maxLevel < 0) { return; } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/HttpClientUtils.java b/common/src/main/java/com/tencent/supersonic/common/util/HttpClientUtils.java index fe4a31f91..63d186f3c 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/HttpClientUtils.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/HttpClientUtils.java @@ -77,34 +77,25 @@ public class HttpClientUtils { private static void init() { try { - SSLConnectionSocketFactory sslConnectionSocketFactory = - new SSLConnectionSocketFactory( - SSLContexts.custom() - .loadTrustMaterial((chain, authType) -> true) - .build(), - new String[] {"SSLv2Hello", "SSLv3", "TLSv1", "TLSv1.1", "TLSv1.2"}, - null, - NoopHostnameVerifier.INSTANCE); + SSLConnectionSocketFactory sslConnectionSocketFactory = new SSLConnectionSocketFactory( + SSLContexts.custom().loadTrustMaterial((chain, authType) -> true).build(), + new String[] {"SSLv2Hello", "SSLv3", "TLSv1", "TLSv1.1", "TLSv1.2"}, null, + NoopHostnameVerifier.INSTANCE); - PoolingHttpClientConnectionManager connManager = - new PoolingHttpClientConnectionManager( - RegistryBuilder.create() - .register( - "http", PlainConnectionSocketFactory.getSocketFactory()) - .register("https", sslConnectionSocketFactory) - .build()); + PoolingHttpClientConnectionManager connManager = new PoolingHttpClientConnectionManager( + RegistryBuilder.create() + .register("http", PlainConnectionSocketFactory.getSocketFactory()) + .register("https", sslConnectionSocketFactory).build()); connManager.setMaxTotal(DEFAULT_MAX_TOTAL_CONN); connManager.setDefaultMaxPerRoute(DEFAULT_MAX_CONN_PERHOST); - RequestConfig requestConfig = - RequestConfig.custom() - // 请求超时时间 - .setConnectTimeout(DEFAULT_CONNECTION_TIMEOUT) - // 等待数据超时时间 - .setSocketTimeout(DEFAULT_READ_TIMEOUT) - // 连接不够用时等待超时时间 - .setConnectionRequestTimeout(DEFAULT_CONN_REQUEST_TIMEOUT) - .build(); + RequestConfig requestConfig = RequestConfig.custom() + // 请求超时时间 + .setConnectTimeout(DEFAULT_CONNECTION_TIMEOUT) + // 等待数据超时时间 + .setSocketTimeout(DEFAULT_READ_TIMEOUT) + // 连接不够用时等待超时时间 + .setConnectionRequestTimeout(DEFAULT_CONN_REQUEST_TIMEOUT).build(); HttpRequestRetryHandler httpRequestRetryHandler = (exception, executionCount, context) -> { @@ -116,49 +107,39 @@ public class HttpClientUtils { } if (exception instanceof NoHttpResponseException) { // 如果服务器丢掉了连接,那么就重试 - log.warn( - "Retry, No response from server on {} error: {}", - executionCount, - exception.getMessage()); + log.warn("Retry, No response from server on {} error: {}", + executionCount, exception.getMessage()); return true; } else if (exception instanceof SocketException) { // 如果服务器断开了连接,那么就重试 - log.warn( - "Retry, No connection from server on {} error: {}", - executionCount, - exception.getMessage()); + log.warn("Retry, No connection from server on {} error: {}", + executionCount, exception.getMessage()); return true; } return false; }; - httpClient = - HttpClients.custom() - // 设置连接池 - .setConnectionManager(connManager) - // 设置超时时间 - .setDefaultRequestConfig(requestConfig) - // 设置连接存活时间 - .setKeepAliveStrategy( - new DefaultConnectionKeepAliveStrategy() { - @Override - public long getKeepAliveDuration( - final HttpResponse response, - final HttpContext context) { - long keepAlive = - super.getKeepAliveDuration(response, context); - if (keepAlive == -1) { - keepAlive = 5000; - } - return keepAlive; - } - }) - .setRetryHandler(httpRequestRetryHandler) - // 设置连接存活时间 - .setConnectionTimeToLive(5000L, TimeUnit.MILLISECONDS) - // 关闭无效和空闲的连接 - .evictIdleConnections(5L, TimeUnit.SECONDS) - .build(); + httpClient = HttpClients.custom() + // 设置连接池 + .setConnectionManager(connManager) + // 设置超时时间 + .setDefaultRequestConfig(requestConfig) + // 设置连接存活时间 + .setKeepAliveStrategy(new DefaultConnectionKeepAliveStrategy() { + @Override + public long getKeepAliveDuration(final HttpResponse response, + final HttpContext context) { + long keepAlive = super.getKeepAliveDuration(response, context); + if (keepAlive == -1) { + keepAlive = 5000; + } + return keepAlive; + } + }).setRetryHandler(httpRequestRetryHandler) + // 设置连接存活时间 + .setConnectionTimeToLive(5000L, TimeUnit.MILLISECONDS) + // 关闭无效和空闲的连接 + .evictIdleConnections(5L, TimeUnit.SECONDS).build(); } catch (Exception e) { log.error(e.getMessage(), e); throw new RuntimeException(e); @@ -193,45 +174,34 @@ public class HttpClientUtils { * * @return */ - public static HttpClientResult doPost( - String url, - String proxyHost, - Integer proxyPort, - Map headers, - Map params) { - return RetryUtils.exec( - () -> { - HttpPost httpPost = null; - CloseableHttpResponse response = null; - try { - httpPost = new HttpPost(url); - setProxy(httpPost, proxyHost, proxyPort); + public static HttpClientResult doPost(String url, String proxyHost, Integer proxyPort, + Map headers, Map params) { + return RetryUtils.exec(() -> { + HttpPost httpPost = null; + CloseableHttpResponse response = null; + try { + httpPost = new HttpPost(url); + setProxy(httpPost, proxyHost, proxyPort); - // 封装header参数 - packageHeader(headers, httpPost); - // 封装请求参数 - packageParam(params, httpPost); + // 封装header参数 + packageHeader(headers, httpPost); + // 封装请求参数 + packageParam(params, httpPost); - response = httpClient.execute(httpPost); - // 获取返回结果 - HttpClientResult result = getHttpClientResult(response); - log.info( - "uri:{}, req:{}, resp:{}", - url, - "headers:" + getHeaders(httpPost) + "------params:" + params, - result); - return result; - } catch (Exception e) { - log.error( - "uri:{}, req:{}", - url, - "headers:" + headers + "------params:" + params, - e); - throw new RuntimeException(e.getMessage()); - } finally { - close(httpPost, response); - } - }); + response = httpClient.execute(httpPost); + // 获取返回结果 + HttpClientResult result = getHttpClientResult(response); + log.info("uri:{}, req:{}, resp:{}", url, + "headers:" + getHeaders(httpPost) + "------params:" + params, result); + return result; + } catch (Exception e) { + log.error("uri:{}, req:{}", url, "headers:" + headers + "------params:" + params, + e); + throw new RuntimeException(e.getMessage()); + } finally { + close(httpPost, response); + } + }); } /** @@ -242,8 +212,8 @@ public class HttpClientUtils { * @param params * @return */ - public static HttpClientResult doPost( - String url, Map header, Map params) { + public static HttpClientResult doPost(String url, Map header, + Map params) { return doPost(url, null, null, header, params); } @@ -279,53 +249,42 @@ public class HttpClientUtils { * @return * @throws Exception */ - public static HttpClientResult doGet( - String url, - String proxyHost, - Integer proxyPort, - Map headers, - Map params) { - return RetryUtils.exec( - () -> { - HttpGet httpGet = null; - CloseableHttpResponse response = null; - try { - // 创建访问的地址 - URIBuilder uriBuilder = new URIBuilder(url); - if (params != null) { - Set> entrySet = params.entrySet(); - for (Map.Entry entry : entrySet) { - uriBuilder.setParameter(entry.getKey(), entry.getValue()); - } - } - - httpGet = new HttpGet(uriBuilder.build()); - setProxy(httpGet, proxyHost, proxyPort); - - // 设置请求头 - packageHeader(headers, httpGet); - - response = httpClient.execute(httpGet); - - // 获取返回结果 - HttpClientResult res = getHttpClientResult(response); - log.debug( - "GET uri:{}, req:{}, resp:{}", - url, - "headers:" + getHeaders(httpGet) + "------params:" + params, - res); - return res; - } catch (Exception e) { - log.error( - "GET error! uri:{}, req:{}", - url, - "headers:" + headers + "------params:" + params, - e); - throw new RuntimeException(e.getMessage()); - } finally { - close(httpGet, response); + public static HttpClientResult doGet(String url, String proxyHost, Integer proxyPort, + Map headers, Map params) { + return RetryUtils.exec(() -> { + HttpGet httpGet = null; + CloseableHttpResponse response = null; + try { + // 创建访问的地址 + URIBuilder uriBuilder = new URIBuilder(url); + if (params != null) { + Set> entrySet = params.entrySet(); + for (Map.Entry entry : entrySet) { + uriBuilder.setParameter(entry.getKey(), entry.getValue()); } - }); + } + + httpGet = new HttpGet(uriBuilder.build()); + setProxy(httpGet, proxyHost, proxyPort); + + // 设置请求头 + packageHeader(headers, httpGet); + + response = httpClient.execute(httpGet); + + // 获取返回结果 + HttpClientResult res = getHttpClientResult(response); + log.debug("GET uri:{}, req:{}, resp:{}", url, + "headers:" + getHeaders(httpGet) + "------params:" + params, res); + return res; + } catch (Exception e) { + log.error("GET error! uri:{}, req:{}", url, + "headers:" + headers + "------params:" + params, e); + throw new RuntimeException(e.getMessage()); + } finally { + close(httpGet, response); + } + }); } /** @@ -336,8 +295,8 @@ public class HttpClientUtils { * @param params * @return */ - public static HttpClientResult doGet( - String url, Map header, Map params) { + public static HttpClientResult doGet(String url, Map header, + Map params) { return doGet(url, null, null, header, params); } @@ -399,9 +358,8 @@ public class HttpClientUtils { * @param httpMethod * @throws UnsupportedEncodingException */ - public static void packageParam( - Map params, HttpEntityEnclosingRequestBase httpMethod) - throws UnsupportedEncodingException { + public static void packageParam(Map params, + HttpEntityEnclosingRequestBase httpMethod) throws UnsupportedEncodingException { if (params != null) { List nvps = new ArrayList(); Set> entrySet = params.entrySet(); @@ -416,13 +374,9 @@ public class HttpClientUtils { public static void setProxy(HttpRequestBase httpMethod, String proxyHost, Integer proxyPort) { if (!StringUtils.isEmpty(proxyHost) && proxyPort != null) { - RequestConfig config = - RequestConfig.custom() - .setProxy(new HttpHost(proxyHost, proxyPort)) - .setConnectTimeout(10000) - .setSocketTimeout(10000) - .setConnectionRequestTimeout(3000) - .build(); + RequestConfig config = RequestConfig.custom() + .setProxy(new HttpHost(proxyHost, proxyPort)).setConnectTimeout(10000) + .setSocketTimeout(10000).setConnectionRequestTimeout(3000).build(); httpMethod.setConfig(config); } } @@ -437,50 +391,39 @@ public class HttpClientUtils { * * @return */ - public static HttpClientResult doPostJSON( - String url, - String proxyHost, - Integer proxyPort, - Map headers, - String req) { - return RetryUtils.exec( - () -> { - HttpPost httpPost = null; - CloseableHttpResponse response = null; - try { - httpPost = new HttpPost(url); - setProxy(httpPost, proxyHost, proxyPort); + public static HttpClientResult doPostJSON(String url, String proxyHost, Integer proxyPort, + Map headers, String req) { + return RetryUtils.exec(() -> { + HttpPost httpPost = null; + CloseableHttpResponse response = null; + try { + httpPost = new HttpPost(url); + setProxy(httpPost, proxyHost, proxyPort); - // 封装header参数 - packageHeader(headers, httpPost); - httpPost.setHeader("Content-Type", "application/json;charset=UTF-8"); + // 封装header参数 + packageHeader(headers, httpPost); + httpPost.setHeader("Content-Type", "application/json;charset=UTF-8"); - // 封装请求参数 - StringEntity stringEntity = new StringEntity(req, ENCODING); // 解决中文乱码问题 - stringEntity.setContentEncoding("UTF-8"); + // 封装请求参数 + StringEntity stringEntity = new StringEntity(req, ENCODING); // 解决中文乱码问题 + stringEntity.setContentEncoding("UTF-8"); - httpPost.setEntity(stringEntity); + httpPost.setEntity(stringEntity); - response = httpClient.execute(httpPost); - // 获取返回结果 - HttpClientResult res = getHttpClientResult(response); - log.info( - "doPostJSON uri:{}, req:{}, resp:{}", - url, - "headers:" + getHeaders(httpPost) + "------req:" + req, - res); - return res; - } catch (Exception e) { - log.error( - "doPostJSON error! uri:{}, req:{}", - url, - "headers:" + headers + "------req:" + req, - e); - throw new RuntimeException(e.getMessage()); - } finally { - close(httpPost, response); - } - }); + response = httpClient.execute(httpPost); + // 获取返回结果 + HttpClientResult res = getHttpClientResult(response); + log.info("doPostJSON uri:{}, req:{}, resp:{}", url, + "headers:" + getHeaders(httpPost) + "------req:" + req, res); + return res; + } catch (Exception e) { + log.error("doPostJSON error! uri:{}, req:{}", url, + "headers:" + headers + "------req:" + req, e); + throw new RuntimeException(e.getMessage()); + } finally { + close(httpPost, response); + } + }); } public static HttpClientResult doPostJSON(String url, String req) { @@ -488,56 +431,45 @@ public class HttpClientUtils { } /** get json */ - public static HttpClientResult doGetJSON( - String url, - String proxyHost, - Integer proxyPort, - Map headers, - Map params) { - return RetryUtils.exec( - () -> { - HttpGet httpGet = null; - CloseableHttpResponse response = null; - try { - // 创建访问的地址 - URIBuilder uriBuilder = new URIBuilder(url); - if (params != null) { - Set> entrySet = params.entrySet(); - for (Map.Entry entry : entrySet) { - uriBuilder.setParameter(entry.getKey(), entry.getValue()); - } - } - - httpGet = new HttpGet(uriBuilder.build()); - setProxy(httpGet, proxyHost, proxyPort); - - // 设置请求头 - packageHeader(headers, httpGet); - httpGet.setHeader("Content-Type", "application/json;charset=UTF-8"); - - response = httpClient.execute(httpGet); - - // 获取返回结果 - HttpClientResult res = getHttpClientResult(response); - - log.info( - "doGetJSON uri:{}, req:{}, resp:{}", - url, - "headers:" + getHeaders(httpGet) + "------params:" + params, - res); - - return res; - } catch (Exception e) { - log.warn( - "doGetJSON error! uri:{}, req:{}", - url, - "headers:" + headers + "------params:" + params, - e); - throw new RuntimeException(e.getMessage()); - } finally { - close(httpGet, response); + public static HttpClientResult doGetJSON(String url, String proxyHost, Integer proxyPort, + Map headers, Map params) { + return RetryUtils.exec(() -> { + HttpGet httpGet = null; + CloseableHttpResponse response = null; + try { + // 创建访问的地址 + URIBuilder uriBuilder = new URIBuilder(url); + if (params != null) { + Set> entrySet = params.entrySet(); + for (Map.Entry entry : entrySet) { + uriBuilder.setParameter(entry.getKey(), entry.getValue()); } - }); + } + + httpGet = new HttpGet(uriBuilder.build()); + setProxy(httpGet, proxyHost, proxyPort); + + // 设置请求头 + packageHeader(headers, httpGet); + httpGet.setHeader("Content-Type", "application/json;charset=UTF-8"); + + response = httpClient.execute(httpGet); + + // 获取返回结果 + HttpClientResult res = getHttpClientResult(response); + + log.info("doGetJSON uri:{}, req:{}, resp:{}", url, + "headers:" + getHeaders(httpGet) + "------params:" + params, res); + + return res; + } catch (Exception e) { + log.warn("doGetJSON error! uri:{}, req:{}", url, + "headers:" + headers + "------params:" + params, e); + throw new RuntimeException(e.getMessage()); + } finally { + close(httpGet, response); + } + }); } private static HttpClientResult getHttpClientResult(CloseableHttpResponse response) @@ -564,82 +496,63 @@ public class HttpClientUtils { * @param fullFilePath * @return */ - public static HttpClientResult doFileUploadBodyParams( - String url, - Map headers, - Map bodyParams, - String fullFilePath) { + public static HttpClientResult doFileUploadBodyParams(String url, Map headers, + Map bodyParams, String fullFilePath) { return doFileUpload(url, null, null, headers, null, bodyParams, fullFilePath); } - public static HttpClientResult doFileUpload( - String url, - String proxyHost, - Integer proxyPort, - Map headers, - Map params, - Map bodyParams, + public static HttpClientResult doFileUpload(String url, String proxyHost, Integer proxyPort, + Map headers, Map params, Map bodyParams, String fullFilePath) { - return RetryUtils.exec( - () -> { - InputStream inputStream = null; - CloseableHttpResponse response = null; - HttpPost httpPost = null; - try { + return RetryUtils.exec(() -> { + InputStream inputStream = null; + CloseableHttpResponse response = null; + HttpPost httpPost = null; + try { - File uploadFile = new File(fullFilePath); - inputStream = new FileInputStream(uploadFile); + File uploadFile = new File(fullFilePath); + inputStream = new FileInputStream(uploadFile); - httpPost = new HttpPost(url); - setProxy(httpPost, proxyHost, proxyPort); + httpPost = new HttpPost(url); + setProxy(httpPost, proxyHost, proxyPort); - packageHeader(headers, httpPost); + packageHeader(headers, httpPost); - HttpEntity entity = - getFileUploadHttpEntity( - params, bodyParams, inputStream, uploadFile.getName()); - httpPost.setEntity(entity); + HttpEntity entity = getFileUploadHttpEntity(params, bodyParams, inputStream, + uploadFile.getName()); + httpPost.setEntity(entity); - response = httpClient.execute(httpPost); - // 执行请求并获得响应结果 - HttpClientResult res = getHttpClientResult(response); - log.info( - "doFileUpload uri:{}, req:{}, resp:{}", - url, - "params:" + params + ", fullFilePath:" + fullFilePath, - res); - return res; - } catch (Exception e) { - log.error( - "doFileUpload error! uri:{}, req:{}", - url, - "params:" + params + ", fullFilePath:" + fullFilePath, - e); - throw new RuntimeException(e.getMessage()); - } finally { - try { - if (null != inputStream) { - inputStream.close(); - } - // 释放资源 - close(httpPost, response); - } catch (IOException e) { - log.error("HttpClientUtils release error!", e); - } + response = httpClient.execute(httpPost); + // 执行请求并获得响应结果 + HttpClientResult res = getHttpClientResult(response); + log.info("doFileUpload uri:{}, req:{}, resp:{}", url, + "params:" + params + ", fullFilePath:" + fullFilePath, res); + return res; + } catch (Exception e) { + log.error("doFileUpload error! uri:{}, req:{}", url, + "params:" + params + ", fullFilePath:" + fullFilePath, e); + throw new RuntimeException(e.getMessage()); + } finally { + try { + if (null != inputStream) { + inputStream.close(); } - }); + // 释放资源 + close(httpPost, response); + } catch (IOException e) { + log.error("HttpClientUtils release error!", e); + } + } + }); } - private static HttpEntity getFileUploadHttpEntity( - Map params, - Map bodyParams, - InputStream inputStream, - String fileName) + private static HttpEntity getFileUploadHttpEntity(Map params, + Map bodyParams, InputStream inputStream, String fileName) throws UnsupportedEncodingException { MultipartEntityBuilder builder = MultipartEntityBuilder.create(); builder.setMode(HttpMultipartMode.BROWSER_COMPATIBLE); - builder.addBinaryBody( - "file", inputStream, ContentType.create("multipart/form-data"), fileName); + builder.addBinaryBody("file", inputStream, ContentType.create("multipart/form-data"), + fileName); if (!CollectionUtils.isEmpty(bodyParams)) { for (String bodyParamsKey : bodyParams.keySet()) { @@ -649,8 +562,7 @@ public class HttpClientUtils { // 构建请求参数 普通表单项 if (!CollectionUtils.isEmpty(params)) { for (Map.Entry entry : params.entrySet()) { - builder.addPart( - entry.getKey(), + builder.addPart(entry.getKey(), new StringBody(entry.getValue(), ContentType.MULTIPART_FORM_DATA)); } } @@ -668,41 +580,34 @@ public class HttpClientUtils { * @return */ public static HttpClientResult doDelete(String url, Map headers, String req) { - return RetryUtils.exec( - () -> { - HttpDeleteWithBody httpDelete = null; - CloseableHttpResponse response = null; - try { - httpDelete = new HttpDeleteWithBody(url); - // 封装header参数 - packageHeader(headers, httpDelete); - httpDelete.setHeader("Content-Type", "application/json;charset=UTF-8"); - // 封装请求参数 - StringEntity stringEntity = new StringEntity(req, ENCODING); // 解决中文乱码问题 - stringEntity.setContentEncoding("UTF-8"); + return RetryUtils.exec(() -> { + HttpDeleteWithBody httpDelete = null; + CloseableHttpResponse response = null; + try { + httpDelete = new HttpDeleteWithBody(url); + // 封装header参数 + packageHeader(headers, httpDelete); + httpDelete.setHeader("Content-Type", "application/json;charset=UTF-8"); + // 封装请求参数 + StringEntity stringEntity = new StringEntity(req, ENCODING); // 解决中文乱码问题 + stringEntity.setContentEncoding("UTF-8"); - httpDelete.setEntity(stringEntity); + httpDelete.setEntity(stringEntity); - response = httpClient.execute(httpDelete); + response = httpClient.execute(httpDelete); - HttpClientResult res = getHttpClientResult(response); - log.info( - "doDeleteJSON uri:{}, req:{}, resp:{}", - url, - "headers:" + getHeaders(httpDelete) + "------req:" + req, - res); - return res; - } catch (Exception e) { - log.error( - "doDeleteJSON error! uri:{}, req:{}", - url, - "headers:" + headers + "------req:" + req, - e); - throw new RuntimeException(e.getMessage()); - } finally { - close(httpDelete, response); - } - }); + HttpClientResult res = getHttpClientResult(response); + log.info("doDeleteJSON uri:{}, req:{}, resp:{}", url, + "headers:" + getHeaders(httpDelete) + "------req:" + req, res); + return res; + } catch (Exception e) { + log.error("doDeleteJSON error! uri:{}, req:{}", url, + "headers:" + headers + "------req:" + req, e); + throw new RuntimeException(e.getMessage()); + } finally { + close(httpDelete, response); + } + }); } private static class HttpDeleteWithBody extends HttpEntityEnclosingRequestBase { @@ -730,37 +635,30 @@ public class HttpClientUtils { } public static HttpClientResult doPutJson(String url, Map headers, String req) { - return RetryUtils.exec( - () -> { - HttpPut httpPut = null; - CloseableHttpResponse response = null; - try { - httpPut = new HttpPut(url); - // 封装header参数 - packageHeader(headers, httpPut); - httpPut.setHeader("Content-Type", "application/json;charset=UTF-8"); - // 封装请求参数 - StringEntity stringEntity = new StringEntity(req, ENCODING); // 解决中文乱码问题 - stringEntity.setContentEncoding("UTF-8"); - httpPut.setEntity(stringEntity); - response = httpClient.execute(httpPut); - HttpClientResult res = getHttpClientResult(response); - log.info( - "doPutJSON uri:{}, req:{}, resp:{}", - url, - "headers:" + getHeaders(httpPut) + "------req:" + req, - res); - return res; - } catch (Exception e) { - log.error( - "doPutJSON error! uri:{}, req:{}", - url, - "headers:" + headers + "------req:" + req, - e); - throw new RuntimeException(e.getMessage()); - } finally { - close(httpPut, response); - } - }); + return RetryUtils.exec(() -> { + HttpPut httpPut = null; + CloseableHttpResponse response = null; + try { + httpPut = new HttpPut(url); + // 封装header参数 + packageHeader(headers, httpPut); + httpPut.setHeader("Content-Type", "application/json;charset=UTF-8"); + // 封装请求参数 + StringEntity stringEntity = new StringEntity(req, ENCODING); // 解决中文乱码问题 + stringEntity.setContentEncoding("UTF-8"); + httpPut.setEntity(stringEntity); + response = httpClient.execute(httpPut); + HttpClientResult res = getHttpClientResult(response); + log.info("doPutJSON uri:{}, req:{}, resp:{}", url, + "headers:" + getHeaders(httpPut) + "------req:" + req, res); + return res; + } catch (Exception e) { + log.error("doPutJSON error! uri:{}, req:{}", url, + "headers:" + headers + "------req:" + req, e); + throw new RuntimeException(e.getMessage()); + } finally { + close(httpPut, response); + } + }); } } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/JsonUtil.java b/common/src/main/java/com/tencent/supersonic/common/util/JsonUtil.java index 93516c92a..5e87a14df 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/JsonUtil.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/JsonUtil.java @@ -30,7 +30,8 @@ public class JsonUtil { public static final JsonUtil INSTANCE = new JsonUtil(); - @Getter private final ObjectMapper objectMapper = new ObjectMapper(); + @Getter + private final ObjectMapper objectMapper = new ObjectMapper(); public JsonUtil() { // 当属性为null时不参与序列化 @@ -400,10 +401,8 @@ public class JsonUtil { try { notNull(keyClass, "key class is null"); notNull(valueClass, "value class is null"); - JavaType type = - objectMapper - .getTypeFactory() - .constructParametricType(Map.class, keyClass, valueClass); + JavaType type = objectMapper.getTypeFactory().constructParametricType(Map.class, + keyClass, valueClass); return objectMapper.readValue(json, type); } catch (Exception e) { throw new JsonException(e); @@ -503,8 +502,7 @@ public class JsonUtil { } try { JsonNode jsonNode = readTree(string); - return objectMapper - .writerWithDefaultPrettyPrinter() + return objectMapper.writerWithDefaultPrettyPrinter() .writeValueAsString(jsonNode); } catch (Exception e) { return string; @@ -617,10 +615,7 @@ public class JsonUtil { super(cause); } - private JsonException( - String message, - Throwable cause, - boolean enableSuppression, + private JsonException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) { super(message, cause, enableSuppression, writableStackTrace); } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/SignatureUtils.java b/common/src/main/java/com/tencent/supersonic/common/util/SignatureUtils.java index 87cb07e67..c22297624 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/SignatureUtils.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/SignatureUtils.java @@ -12,8 +12,8 @@ public class SignatureUtils { return DigestUtils.sha1Hex(psw); } - public static Pair isValidSignature( - String appKey, String appSecret, long timestamp, String signatureToCheck) { + public static Pair isValidSignature(String appKey, String appSecret, + long timestamp, String signatureToCheck) { long currentTimeMillis = System.currentTimeMillis(); if (currentTimeMillis < timestamp) { diff --git a/common/src/main/java/com/tencent/supersonic/common/util/SqlFilterUtils.java b/common/src/main/java/com/tencent/supersonic/common/util/SqlFilterUtils.java index 1d0b6aede..32ac02eaf 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/SqlFilterUtils.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/SqlFilterUtils.java @@ -59,13 +59,11 @@ public class SqlFilterUtils { StringJoiner joiner = new StringJoiner(Constants.AND_UPPER); if (!CollectionUtils.isEmpty(filters)) { - filters.stream() - .forEach( - filter -> { - if (StringUtils.isNotEmpty(dealFilter(filter, isBizName))) { - joiner.add(SPACE + dealFilter(filter, isBizName) + SPACE); - } - }); + filters.stream().forEach(filter -> { + if (StringUtils.isNotEmpty(dealFilter(filter, isBizName))) { + joiner.add(SPACE + dealFilter(filter, isBizName) + SPACE); + } + }); log.debug("getWhereClause, where sql : {}", joiner); return joiner.toString(); } @@ -160,8 +158,8 @@ public class SqlFilterUtils { throw new RuntimeException("criterion.getValue() can not be null"); } StringBuilder whereClause = new StringBuilder(); - whereClause.append( - criterion.getColumn() + SPACE + criterion.getOperator().getValue() + SPACE); + whereClause + .append(criterion.getColumn() + SPACE + criterion.getOperator().getValue() + SPACE); String value = criterion.getValue().toString(); if (criterion.isNeedApostrophe() && !Pattern.matches(pattern, value)) { // like click => 'like%' @@ -170,10 +168,9 @@ public class SqlFilterUtils { } else { // like 'click' => 'like%' - whereClause.append( - Constants.APOSTROPHE - + value.replaceAll(Constants.APOSTROPHE, Constants.PERCENT_SIGN) - + Constants.APOSTROPHE); + whereClause.append(Constants.APOSTROPHE + + value.replaceAll(Constants.APOSTROPHE, Constants.PERCENT_SIGN) + + Constants.APOSTROPHE); } return whereClause.toString(); } @@ -184,8 +181,8 @@ public class SqlFilterUtils { } StringBuilder whereClause = new StringBuilder(); - whereClause.append( - criterion.getColumn() + SPACE + criterion.getOperator().getValue() + SPACE); + whereClause + .append(criterion.getColumn() + SPACE + criterion.getOperator().getValue() + SPACE); List values = (List) criterion.getValue(); whereClause.append(PARENTHESES_START); StringJoiner joiner = new StringJoiner(","); @@ -209,19 +206,12 @@ public class SqlFilterUtils { } if (criterion.isNeedApostrophe()) { - return String.format( - "(%s >= %s and %s <= %s)", - criterion.getColumn(), - valueApostropheLogic(values.get(0).toString()), - criterion.getColumn(), + return String.format("(%s >= %s and %s <= %s)", criterion.getColumn(), + valueApostropheLogic(values.get(0).toString()), criterion.getColumn(), valueApostropheLogic(values.get(1).toString())); } - return String.format( - "(%s >= %s and %s <= %s)", - criterion.getColumn(), - values.get(0).toString(), - criterion.getColumn(), - values.get(1).toString()); + return String.format("(%s >= %s and %s <= %s)", criterion.getColumn(), + values.get(0).toString(), criterion.getColumn(), values.get(1).toString()); } private String singleValueLogic(Criterion criterion) { @@ -229,8 +219,8 @@ public class SqlFilterUtils { throw new RuntimeException("criterion.getValue() can not be null"); } StringBuilder whereClause = new StringBuilder(); - whereClause.append( - criterion.getColumn() + SPACE + criterion.getOperator().getValue() + SPACE); + whereClause + .append(criterion.getColumn() + SPACE + criterion.getOperator().getValue() + SPACE); String value = criterion.getValue().toString(); if (criterion.isNeedApostrophe()) { value = valueApostropheLogic(value); @@ -258,10 +248,7 @@ public class SqlFilterUtils { if (Objects.isNull(criterion) || Objects.isNull(criterion.getValue())) { throw new RuntimeException("criterion.getValue() can not be null"); } - return PARENTHESES_START - + SPACE - + criterion.getValue().toString() - + SPACE + return PARENTHESES_START + SPACE + criterion.getValue().toString() + SPACE + PARENTHESES_END; } } diff --git a/common/src/main/java/com/tencent/supersonic/common/util/StringUtil.java b/common/src/main/java/com/tencent/supersonic/common/util/StringUtil.java index 8ad0c33ec..2b83736e1 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/StringUtil.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/StringUtil.java @@ -28,7 +28,7 @@ public class StringUtil { * @param v1 * @param v2 * @return value 0 if v1 equal to v2; less than 0 if v1 is less than v2; greater than 0 if v1 is - * greater than v2 + * greater than v2 */ public static int compareVersion(String v1, String v2) { String[] v1s = v1.split("\\."); diff --git a/common/src/main/java/com/tencent/supersonic/common/util/ThreadMdcUtil.java b/common/src/main/java/com/tencent/supersonic/common/util/ThreadMdcUtil.java index acd4ddf40..18559de35 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/ThreadMdcUtil.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/ThreadMdcUtil.java @@ -12,8 +12,8 @@ public class ThreadMdcUtil { } } - public static Callable wrap( - final Callable callable, final Map context) { + public static Callable wrap(final Callable callable, + final Map context) { return () -> { if (context == null) { MDC.clear(); diff --git a/common/src/main/java/com/tencent/supersonic/common/util/YamlUtils.java b/common/src/main/java/com/tencent/supersonic/common/util/YamlUtils.java index cb79c3b3d..cb04c3ba6 100644 --- a/common/src/main/java/com/tencent/supersonic/common/util/YamlUtils.java +++ b/common/src/main/java/com/tencent/supersonic/common/util/YamlUtils.java @@ -51,10 +51,8 @@ public class YamlUtils { .disable(YAMLGenerator.Feature.LITERAL_BLOCK_STYLE); try { String yaml = mapper.writeValueAsString(object); - return yaml.replaceAll("\"True\"", "true") - .replaceAll("\"true\"", "true") - .replaceAll("\"false\"", "false") - .replaceAll("\"False\"", "false"); + return yaml.replaceAll("\"True\"", "true").replaceAll("\"true\"", "true") + .replaceAll("\"false\"", "false").replaceAll("\"False\"", "false"); } catch (IOException e) { log.error("", e); } diff --git a/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java index dd980881c..fcd135c7d 100644 --- a/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/chroma/spring/ChromaEmbeddingStoreFactory.java @@ -24,11 +24,8 @@ public class ChromaEmbeddingStoreFactory extends BaseEmbeddingStoreFactory { @Override public EmbeddingStore createEmbeddingStore(String collectionName) { - return ChromaEmbeddingStore.builder() - .baseUrl(storeProperties.getBaseUrl()) - .collectionName(collectionName) - .timeout(storeProperties.getTimeout()) - .build(); + return ChromaEmbeddingStore.builder().baseUrl(storeProperties.getBaseUrl()) + .collectionName(collectionName).timeout(storeProperties.getTimeout()).build(); } private static EmbeddingStoreProperties createPropertiesFromConfig( diff --git a/common/src/main/java/dev/langchain4j/chroma/spring/Properties.java b/common/src/main/java/dev/langchain4j/chroma/spring/Properties.java index 26141ff19..cd02a5c46 100644 --- a/common/src/main/java/dev/langchain4j/chroma/spring/Properties.java +++ b/common/src/main/java/dev/langchain4j/chroma/spring/Properties.java @@ -12,5 +12,6 @@ public class Properties { static final String PREFIX = "langchain4j.chroma"; - @NestedConfigurationProperty EmbeddingStoreProperties embeddingStore; + @NestedConfigurationProperty + EmbeddingStoreProperties embeddingStore; } diff --git a/common/src/main/java/dev/langchain4j/dashscope/spring/DashscopeAutoConfig.java b/common/src/main/java/dev/langchain4j/dashscope/spring/DashscopeAutoConfig.java index db915d59d..b74408496 100644 --- a/common/src/main/java/dev/langchain4j/dashscope/spring/DashscopeAutoConfig.java +++ b/common/src/main/java/dev/langchain4j/dashscope/spring/DashscopeAutoConfig.java @@ -20,18 +20,15 @@ public class DashscopeAutoConfig { @ConditionalOnProperty(PREFIX + ".chat-model.api-key") QwenChatModel qwenChatModel(Properties properties) { ChatModelProperties chatModelProperties = properties.getChatModel(); - return QwenChatModel.builder() - .baseUrl(chatModelProperties.getBaseUrl()) + return QwenChatModel.builder().baseUrl(chatModelProperties.getBaseUrl()) .apiKey(chatModelProperties.getApiKey()) - .modelName(chatModelProperties.getModelName()) - .topP(chatModelProperties.getTopP()) + .modelName(chatModelProperties.getModelName()).topP(chatModelProperties.getTopP()) .topK(chatModelProperties.getTopK()) .enableSearch(chatModelProperties.getEnableSearch()) .seed(chatModelProperties.getSeed()) .repetitionPenalty(chatModelProperties.getRepetitionPenalty()) .temperature(chatModelProperties.getTemperature()) - .stops(chatModelProperties.getStops()) - .maxTokens(chatModelProperties.getMaxTokens()) + .stops(chatModelProperties.getStops()).maxTokens(chatModelProperties.getMaxTokens()) .build(); } @@ -39,18 +36,15 @@ public class DashscopeAutoConfig { @ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key") QwenStreamingChatModel qwenStreamingChatModel(Properties properties) { ChatModelProperties chatModelProperties = properties.getStreamingChatModel(); - return QwenStreamingChatModel.builder() - .baseUrl(chatModelProperties.getBaseUrl()) + return QwenStreamingChatModel.builder().baseUrl(chatModelProperties.getBaseUrl()) .apiKey(chatModelProperties.getApiKey()) - .modelName(chatModelProperties.getModelName()) - .topP(chatModelProperties.getTopP()) + .modelName(chatModelProperties.getModelName()).topP(chatModelProperties.getTopP()) .topK(chatModelProperties.getTopK()) .enableSearch(chatModelProperties.getEnableSearch()) .seed(chatModelProperties.getSeed()) .repetitionPenalty(chatModelProperties.getRepetitionPenalty()) .temperature(chatModelProperties.getTemperature()) - .stops(chatModelProperties.getStops()) - .maxTokens(chatModelProperties.getMaxTokens()) + .stops(chatModelProperties.getStops()).maxTokens(chatModelProperties.getMaxTokens()) .build(); } @@ -58,47 +52,33 @@ public class DashscopeAutoConfig { @ConditionalOnProperty(PREFIX + ".language-model.api-key") QwenLanguageModel qwenLanguageModel(Properties properties) { ChatModelProperties languageModel = properties.getLanguageModel(); - return QwenLanguageModel.builder() - .baseUrl(languageModel.getBaseUrl()) - .apiKey(languageModel.getApiKey()) - .modelName(languageModel.getModelName()) - .topP(languageModel.getTopP()) - .topK(languageModel.getTopK()) - .enableSearch(languageModel.getEnableSearch()) - .seed(languageModel.getSeed()) + return QwenLanguageModel.builder().baseUrl(languageModel.getBaseUrl()) + .apiKey(languageModel.getApiKey()).modelName(languageModel.getModelName()) + .topP(languageModel.getTopP()).topK(languageModel.getTopK()) + .enableSearch(languageModel.getEnableSearch()).seed(languageModel.getSeed()) .repetitionPenalty(languageModel.getRepetitionPenalty()) - .temperature(languageModel.getTemperature()) - .stops(languageModel.getStops()) - .maxTokens(languageModel.getMaxTokens()) - .build(); + .temperature(languageModel.getTemperature()).stops(languageModel.getStops()) + .maxTokens(languageModel.getMaxTokens()).build(); } @Bean @ConditionalOnProperty(PREFIX + ".streaming-language-model.api-key") QwenStreamingLanguageModel qwenStreamingLanguageModel(Properties properties) { ChatModelProperties languageModel = properties.getStreamingLanguageModel(); - return QwenStreamingLanguageModel.builder() - .baseUrl(languageModel.getBaseUrl()) - .apiKey(languageModel.getApiKey()) - .modelName(languageModel.getModelName()) - .topP(languageModel.getTopP()) - .topK(languageModel.getTopK()) - .enableSearch(languageModel.getEnableSearch()) - .seed(languageModel.getSeed()) + return QwenStreamingLanguageModel.builder().baseUrl(languageModel.getBaseUrl()) + .apiKey(languageModel.getApiKey()).modelName(languageModel.getModelName()) + .topP(languageModel.getTopP()).topK(languageModel.getTopK()) + .enableSearch(languageModel.getEnableSearch()).seed(languageModel.getSeed()) .repetitionPenalty(languageModel.getRepetitionPenalty()) - .temperature(languageModel.getTemperature()) - .stops(languageModel.getStops()) - .maxTokens(languageModel.getMaxTokens()) - .build(); + .temperature(languageModel.getTemperature()).stops(languageModel.getStops()) + .maxTokens(languageModel.getMaxTokens()).build(); } @Bean @ConditionalOnProperty(PREFIX + ".embedding-model.api-key") QwenEmbeddingModel qwenEmbeddingModel(Properties properties) { EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel(); - return QwenEmbeddingModel.builder() - .apiKey(embeddingModelProperties.getApiKey()) - .modelName(embeddingModelProperties.getModelName()) - .build(); + return QwenEmbeddingModel.builder().apiKey(embeddingModelProperties.getApiKey()) + .modelName(embeddingModelProperties.getModelName()).build(); } } diff --git a/common/src/main/java/dev/langchain4j/dashscope/spring/Properties.java b/common/src/main/java/dev/langchain4j/dashscope/spring/Properties.java index 53852c433..b232a7549 100644 --- a/common/src/main/java/dev/langchain4j/dashscope/spring/Properties.java +++ b/common/src/main/java/dev/langchain4j/dashscope/spring/Properties.java @@ -12,13 +12,18 @@ public class Properties { static final String PREFIX = "langchain4j.dashscope"; - @NestedConfigurationProperty ChatModelProperties chatModel; + @NestedConfigurationProperty + ChatModelProperties chatModel; - @NestedConfigurationProperty ChatModelProperties streamingChatModel; + @NestedConfigurationProperty + ChatModelProperties streamingChatModel; - @NestedConfigurationProperty ChatModelProperties languageModel; + @NestedConfigurationProperty + ChatModelProperties languageModel; - @NestedConfigurationProperty ChatModelProperties streamingLanguageModel; + @NestedConfigurationProperty + ChatModelProperties streamingLanguageModel; - @NestedConfigurationProperty EmbeddingModelProperties embeddingModel; + @NestedConfigurationProperty + EmbeddingModelProperties embeddingModel; } diff --git a/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java index 56c35de95..40b16802e 100644 --- a/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/inmemory/spring/InMemoryEmbeddingStoreFactory.java @@ -74,8 +74,8 @@ public class InMemoryEmbeddingStoreFactory extends BaseEmbeddingStoreFactory { if (MapUtils.isEmpty(super.collectionNameToStore)) { return; } - for (Map.Entry> entry : - collectionNameToStore.entrySet()) { + for (Map.Entry> entry : collectionNameToStore + .entrySet()) { Path filePath = getPersistPath(entry.getKey()); if (Objects.isNull(filePath)) { continue; diff --git a/common/src/main/java/dev/langchain4j/inmemory/spring/Properties.java b/common/src/main/java/dev/langchain4j/inmemory/spring/Properties.java index 281496c8e..87bc971a9 100644 --- a/common/src/main/java/dev/langchain4j/inmemory/spring/Properties.java +++ b/common/src/main/java/dev/langchain4j/inmemory/spring/Properties.java @@ -12,7 +12,9 @@ public class Properties { static final String PREFIX = "langchain4j.in-memory"; - @NestedConfigurationProperty EmbeddingStoreProperties embeddingStore; + @NestedConfigurationProperty + EmbeddingStoreProperties embeddingStore; - @NestedConfigurationProperty EmbeddingModelProperties embeddingModel; + @NestedConfigurationProperty + EmbeddingModelProperties embeddingModel; } diff --git a/common/src/main/java/dev/langchain4j/localai/spring/LocalAiAutoConfig.java b/common/src/main/java/dev/langchain4j/localai/spring/LocalAiAutoConfig.java index 817451b1a..81aeddc55 100644 --- a/common/src/main/java/dev/langchain4j/localai/spring/LocalAiAutoConfig.java +++ b/common/src/main/java/dev/langchain4j/localai/spring/LocalAiAutoConfig.java @@ -20,70 +20,58 @@ public class LocalAiAutoConfig { @ConditionalOnProperty(PREFIX + ".chat-model.base-url") LocalAiChatModel localAiChatModel(Properties properties) { ChatModelProperties chatModelProperties = properties.getChatModel(); - return LocalAiChatModel.builder() - .baseUrl(chatModelProperties.getBaseUrl()) + return LocalAiChatModel.builder().baseUrl(chatModelProperties.getBaseUrl()) .modelName(chatModelProperties.getModelName()) .temperature(chatModelProperties.getTemperature()) - .topP(chatModelProperties.getTopP()) - .maxRetries(chatModelProperties.getMaxRetries()) + .topP(chatModelProperties.getTopP()).maxRetries(chatModelProperties.getMaxRetries()) .logRequests(chatModelProperties.getLogRequests()) - .logResponses(chatModelProperties.getLogResponses()) - .build(); + .logResponses(chatModelProperties.getLogResponses()).build(); } @Bean @ConditionalOnProperty(PREFIX + ".streaming-chat-model.base-url") LocalAiStreamingChatModel localAiStreamingChatModel(Properties properties) { ChatModelProperties chatModelProperties = properties.getStreamingChatModel(); - return LocalAiStreamingChatModel.builder() - .temperature(chatModelProperties.getTemperature()) - .topP(chatModelProperties.getTopP()) - .baseUrl(chatModelProperties.getBaseUrl()) + return LocalAiStreamingChatModel.builder().temperature(chatModelProperties.getTemperature()) + .topP(chatModelProperties.getTopP()).baseUrl(chatModelProperties.getBaseUrl()) .modelName(chatModelProperties.getModelName()) .logRequests(chatModelProperties.getLogRequests()) - .logResponses(chatModelProperties.getLogResponses()) - .build(); + .logResponses(chatModelProperties.getLogResponses()).build(); } @Bean @ConditionalOnProperty(PREFIX + ".language-model.base-url") LocalAiLanguageModel localAiLanguageModel(Properties properties) { LanguageModelProperties languageModelProperties = properties.getLanguageModel(); - return LocalAiLanguageModel.builder() - .topP(languageModelProperties.getTopP()) + return LocalAiLanguageModel.builder().topP(languageModelProperties.getTopP()) .baseUrl(languageModelProperties.getBaseUrl()) .modelName(languageModelProperties.getModelName()) .temperature(languageModelProperties.getTemperature()) .maxRetries(languageModelProperties.getMaxRetries()) .logRequests(languageModelProperties.getLogRequests()) - .logResponses(languageModelProperties.getLogResponses()) - .build(); + .logResponses(languageModelProperties.getLogResponses()).build(); } @Bean @ConditionalOnProperty(PREFIX + ".streaming-language-model.base-url") LocalAiStreamingLanguageModel localAiStreamingLanguageModel(Properties properties) { LanguageModelProperties languageModelProperties = properties.getStreamingLanguageModel(); - return LocalAiStreamingLanguageModel.builder() - .topP(languageModelProperties.getTopP()) + return LocalAiStreamingLanguageModel.builder().topP(languageModelProperties.getTopP()) .baseUrl(languageModelProperties.getBaseUrl()) .modelName(languageModelProperties.getModelName()) .temperature(languageModelProperties.getTemperature()) .logRequests(languageModelProperties.getLogRequests()) - .logResponses(languageModelProperties.getLogResponses()) - .build(); + .logResponses(languageModelProperties.getLogResponses()).build(); } @Bean @ConditionalOnProperty(PREFIX + ".embedding-model.base-url") LocalAiEmbeddingModel localAiEmbeddingModel(Properties properties) { EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel(); - return LocalAiEmbeddingModel.builder() - .baseUrl(embeddingModelProperties.getBaseUrl()) + return LocalAiEmbeddingModel.builder().baseUrl(embeddingModelProperties.getBaseUrl()) .modelName(embeddingModelProperties.getModelName()) .maxRetries(embeddingModelProperties.getMaxRetries()) .logRequests(embeddingModelProperties.getLogRequests()) - .logResponses(embeddingModelProperties.getLogResponses()) - .build(); + .logResponses(embeddingModelProperties.getLogResponses()).build(); } } diff --git a/common/src/main/java/dev/langchain4j/localai/spring/Properties.java b/common/src/main/java/dev/langchain4j/localai/spring/Properties.java index e8b399ead..61ae589b7 100644 --- a/common/src/main/java/dev/langchain4j/localai/spring/Properties.java +++ b/common/src/main/java/dev/langchain4j/localai/spring/Properties.java @@ -12,13 +12,18 @@ public class Properties { static final String PREFIX = "langchain4j.local-ai"; - @NestedConfigurationProperty ChatModelProperties chatModel; + @NestedConfigurationProperty + ChatModelProperties chatModel; - @NestedConfigurationProperty ChatModelProperties streamingChatModel; + @NestedConfigurationProperty + ChatModelProperties streamingChatModel; - @NestedConfigurationProperty LanguageModelProperties languageModel; + @NestedConfigurationProperty + LanguageModelProperties languageModel; - @NestedConfigurationProperty LanguageModelProperties streamingLanguageModel; + @NestedConfigurationProperty + LanguageModelProperties streamingLanguageModel; - @NestedConfigurationProperty EmbeddingModelProperties embeddingModel; + @NestedConfigurationProperty + EmbeddingModelProperties embeddingModel; } diff --git a/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java b/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java index 1f2d15f0f..eca6aa7d5 100644 --- a/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java +++ b/common/src/main/java/dev/langchain4j/milvus/spring/MilvusEmbeddingStoreFactory.java @@ -29,21 +29,15 @@ public class MilvusEmbeddingStoreFactory extends BaseEmbeddingStoreFactory { @Override public EmbeddingStore createEmbeddingStore(String collectionName) { - return MilvusEmbeddingStore.builder() - .host(storeProperties.getHost()) - .port(storeProperties.getPort()) - .collectionName(collectionName) - .dimension(storeProperties.getDimension()) - .indexType(storeProperties.getIndexType()) - .metricType(storeProperties.getMetricType()) - .uri(storeProperties.getUri()) - .token(storeProperties.getToken()) - .username(storeProperties.getUsername()) + return MilvusEmbeddingStore.builder().host(storeProperties.getHost()) + .port(storeProperties.getPort()).collectionName(collectionName) + .dimension(storeProperties.getDimension()).indexType(storeProperties.getIndexType()) + .metricType(storeProperties.getMetricType()).uri(storeProperties.getUri()) + .token(storeProperties.getToken()).username(storeProperties.getUsername()) .password(storeProperties.getPassword()) .consistencyLevel(storeProperties.getConsistencyLevel()) .retrieveEmbeddingsOnSearch(storeProperties.getRetrieveEmbeddingsOnSearch()) .autoFlushOnInsert(storeProperties.getAutoFlushOnInsert()) - .databaseName(storeProperties.getDatabaseName()) - .build(); + .databaseName(storeProperties.getDatabaseName()).build(); } } diff --git a/common/src/main/java/dev/langchain4j/milvus/spring/Properties.java b/common/src/main/java/dev/langchain4j/milvus/spring/Properties.java index 36ea54e24..7584e7942 100644 --- a/common/src/main/java/dev/langchain4j/milvus/spring/Properties.java +++ b/common/src/main/java/dev/langchain4j/milvus/spring/Properties.java @@ -12,5 +12,6 @@ public class Properties { static final String PREFIX = "langchain4j.milvus"; - @NestedConfigurationProperty EmbeddingStoreProperties embeddingStore; + @NestedConfigurationProperty + EmbeddingStoreProperties embeddingStore; } diff --git a/common/src/main/java/dev/langchain4j/model/embedding/S2OnnxEmbeddingModel.java b/common/src/main/java/dev/langchain4j/model/embedding/S2OnnxEmbeddingModel.java index a1ba91439..45db6c414 100644 --- a/common/src/main/java/dev/langchain4j/model/embedding/S2OnnxEmbeddingModel.java +++ b/common/src/main/java/dev/langchain4j/model/embedding/S2OnnxEmbeddingModel.java @@ -13,10 +13,10 @@ import java.util.Objects; /** * An embedding model that runs within your Java application's process. Any BERT-based model (e.g., * from HuggingFace) can be used, as long as it is in ONNX format. Information on how to convert - * models into ONNX format can be found here. - * Many models already converted to ONNX format are available here. Copy from + * models into ONNX format can be found here. Many + * models already converted to ONNX format are available + * here. Copy from * dev.langchain4j.model.embedding.OnnxEmbeddingModel. */ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel { @@ -28,9 +28,8 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel { if (shouldReloadModel(pathToModel, vocabularyPath)) { synchronized (S2OnnxEmbeddingModel.class) { if (shouldReloadModel(pathToModel, vocabularyPath)) { - URL resource = - AbstractInProcessEmbeddingModel.class.getResource( - "/bert-vocabulary-en.txt"); + URL resource = AbstractInProcessEmbeddingModel.class + .getResource("/bert-vocabulary-en.txt"); if (StringUtils.isNotBlank(vocabularyPath)) { try { resource = Paths.get(vocabularyPath).toUri().toURL(); @@ -56,15 +55,14 @@ public class S2OnnxEmbeddingModel extends AbstractInProcessEmbeddingModel { } private static boolean shouldReloadModel(String pathToModel, String vocabularyPath) { - return cachedModel == null - || !Objects.equals(cachedModelPath, pathToModel) + return cachedModel == null || !Objects.equals(cachedModelPath, pathToModel) || !Objects.equals(cachedVocabularyPath, vocabularyPath); } static OnnxBertBiEncoder loadFromFileSystem(Path pathToModel, URL vocabularyFile) { try { - return new OnnxBertBiEncoder( - Files.newInputStream(pathToModel), vocabularyFile, PoolingMode.MEAN); + return new OnnxBertBiEncoder(Files.newInputStream(pathToModel), vocabularyFile, + PoolingMode.MEAN); } catch (IOException e) { throw new RuntimeException(e); } diff --git a/common/src/main/java/dev/langchain4j/model/openai/OpenAiChatModel.java b/common/src/main/java/dev/langchain4j/model/openai/OpenAiChatModel.java index efec94a90..6bd56fae8 100644 --- a/common/src/main/java/dev/langchain4j/model/openai/OpenAiChatModel.java +++ b/common/src/main/java/dev/langchain4j/model/openai/OpenAiChatModel.java @@ -60,8 +60,8 @@ import static java.util.Collections.singletonList; /** * Represents an OpenAI language model with a chat completion interface, such as gpt-3.5-turbo and - * gpt-4. You can find description of parameters here. + * gpt-4. You can find description of parameters + * here. */ @Slf4j public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator { @@ -88,32 +88,13 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator { private final List listeners; @Builder - public OpenAiChatModel( - String baseUrl, - String apiKey, - String organizationId, - String modelName, - Double temperature, - Double topP, - List stop, - Integer maxTokens, - Double presencePenalty, - Double frequencyPenalty, - Map logitBias, - String responseFormat, - Boolean strictJsonSchema, - Integer seed, - String user, - Boolean strictTools, - Boolean parallelToolCalls, - Duration timeout, - Integer maxRetries, - Proxy proxy, - Boolean logRequests, - Boolean logResponses, - Tokenizer tokenizer, - Map customHeaders, - List listeners) { + public OpenAiChatModel(String baseUrl, String apiKey, String organizationId, String modelName, + Double temperature, Double topP, List stop, Integer maxTokens, + Double presencePenalty, Double frequencyPenalty, Map logitBias, + String responseFormat, Boolean strictJsonSchema, Integer seed, String user, + Boolean strictTools, Boolean parallelToolCalls, Duration timeout, Integer maxRetries, + Proxy proxy, Boolean logRequests, Boolean logResponses, Tokenizer tokenizer, + Map customHeaders, List listeners) { baseUrl = getOrDefault(baseUrl, OPENAI_URL); if (OPENAI_DEMO_API_KEY.equals(apiKey)) { @@ -123,21 +104,11 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator { timeout = getOrDefault(timeout, ofSeconds(60)); - this.client = - OpenAiClient.builder() - .openAiApiKey(apiKey) - .baseUrl(baseUrl) - .organizationId(organizationId) - .callTimeout(timeout) - .connectTimeout(timeout) - .readTimeout(timeout) - .writeTimeout(timeout) - .proxy(proxy) - .logRequests(logRequests) - .logResponses(logResponses) - .userAgent(DEFAULT_USER_AGENT) - .customHeaders(customHeaders) - .build(); + this.client = OpenAiClient.builder().openAiApiKey(apiKey).baseUrl(baseUrl) + .organizationId(organizationId).callTimeout(timeout).connectTimeout(timeout) + .readTimeout(timeout).writeTimeout(timeout).proxy(proxy).logRequests(logRequests) + .logResponses(logResponses).userAgent(DEFAULT_USER_AGENT) + .customHeaders(customHeaders).build(); this.modelName = getOrDefault(modelName, GPT_3_5_TURBO); this.temperature = getOrDefault(temperature, 0.7); this.topP = topP; @@ -146,14 +117,10 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator { this.presencePenalty = presencePenalty; this.frequencyPenalty = frequencyPenalty; this.logitBias = logitBias; - this.responseFormat = - responseFormat == null - ? null - : ResponseFormat.builder() - .type( - ResponseFormatType.valueOf( - responseFormat.toUpperCase(Locale.ROOT))) - .build(); + this.responseFormat = responseFormat == null ? null + : ResponseFormat.builder() + .type(ResponseFormatType.valueOf(responseFormat.toUpperCase(Locale.ROOT))) + .build(); this.strictJsonSchema = getOrDefault(strictJsonSchema, false); this.seed = seed; this.user = user; @@ -183,61 +150,44 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator { } @Override - public Response generate( - List messages, List toolSpecifications) { + public Response generate(List messages, + List toolSpecifications) { return generate(messages, toolSpecifications, null, this.responseFormat); } @Override - public Response generate( - List messages, ToolSpecification toolSpecification) { - return generate( - messages, singletonList(toolSpecification), toolSpecification, this.responseFormat); + public Response generate(List messages, + ToolSpecification toolSpecification) { + return generate(messages, singletonList(toolSpecification), toolSpecification, + this.responseFormat); } @Override public ChatResponse chat(ChatRequest request) { Response response = - generate( - request.messages(), - request.toolSpecifications(), - null, + generate(request.messages(), request.toolSpecifications(), null, getOrDefault( toOpenAiResponseFormat(request.responseFormat(), strictJsonSchema), this.responseFormat)); - return ChatResponse.builder() - .aiMessage(response.content()) - .tokenUsage(response.tokenUsage()) - .finishReason(response.finishReason()) - .build(); + return ChatResponse.builder().aiMessage(response.content()) + .tokenUsage(response.tokenUsage()).finishReason(response.finishReason()).build(); } - private Response generate( - List messages, - List toolSpecifications, - ToolSpecification toolThatMustBeExecuted, + private Response generate(List messages, + List toolSpecifications, ToolSpecification toolThatMustBeExecuted, ResponseFormat responseFormat) { - if (responseFormat != null - && responseFormat.type() == JSON_SCHEMA + if (responseFormat != null && responseFormat.type() == JSON_SCHEMA && responseFormat.jsonSchema() == null) { responseFormat = null; } - ChatCompletionRequest.Builder requestBuilder = - ChatCompletionRequest.builder() - .model(modelName) - .messages(toOpenAiMessages(messages)) - .topP(topP) - .stop(stop) - .maxTokens(maxTokens) - .presencePenalty(presencePenalty) - .frequencyPenalty(frequencyPenalty) - .logitBias(logitBias) - .responseFormat(responseFormat) - .seed(seed) - .user(user) - .parallelToolCalls(parallelToolCalls); + ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder() + .model(modelName).messages(toOpenAiMessages(messages)).topP(topP).stop(stop) + .maxTokens(maxTokens).presencePenalty(presencePenalty) + .frequencyPenalty(frequencyPenalty).logitBias(logitBias) + .responseFormat(responseFormat).seed(seed).user(user) + .parallelToolCalls(parallelToolCalls); if (!(baseUrl.contains(ZHIPU))) { requestBuilder.temperature(temperature); @@ -257,40 +207,33 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator { Map attributes = new ConcurrentHashMap<>(); ChatModelRequestContext requestContext = new ChatModelRequestContext(modelListenerRequest, attributes); - listeners.forEach( - listener -> { - try { - listener.onRequest(requestContext); - } catch (Exception e) { - log.warn("Exception while calling model listener", e); - } - }); + listeners.forEach(listener -> { + try { + listener.onRequest(requestContext); + } catch (Exception e) { + log.warn("Exception while calling model listener", e); + } + }); try { ChatCompletionResponse chatCompletionResponse = withRetry(() -> client.chatCompletion(request).execute(), maxRetries); - Response response = - Response.from( - aiMessageFrom(chatCompletionResponse), - tokenUsageFrom(chatCompletionResponse.usage()), - finishReasonFrom( - chatCompletionResponse.choices().get(0).finishReason())); + Response response = Response.from(aiMessageFrom(chatCompletionResponse), + tokenUsageFrom(chatCompletionResponse.usage()), + finishReasonFrom(chatCompletionResponse.choices().get(0).finishReason())); - ChatModelResponse modelListenerResponse = - createModelListenerResponse( - chatCompletionResponse.id(), chatCompletionResponse.model(), response); - ChatModelResponseContext responseContext = - new ChatModelResponseContext( - modelListenerResponse, modelListenerRequest, attributes); - listeners.forEach( - listener -> { - try { - listener.onResponse(responseContext); - } catch (Exception e) { - log.warn("Exception while calling model listener", e); - } - }); + ChatModelResponse modelListenerResponse = createModelListenerResponse( + chatCompletionResponse.id(), chatCompletionResponse.model(), response); + ChatModelResponseContext responseContext = new ChatModelResponseContext( + modelListenerResponse, modelListenerRequest, attributes); + listeners.forEach(listener -> { + try { + listener.onResponse(responseContext); + } catch (Exception e) { + log.warn("Exception while calling model listener", e); + } + }); return response; } catch (RuntimeException e) { @@ -305,14 +248,13 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator { ChatModelErrorContext errorContext = new ChatModelErrorContext(error, modelListenerRequest, null, attributes); - listeners.forEach( - listener -> { - try { - listener.onError(errorContext); - } catch (Exception e2) { - log.warn("Exception while calling model listener", e2); - } - }); + listeners.forEach(listener -> { + try { + listener.onError(errorContext); + } catch (Exception e2) { + log.warn("Exception while calling model listener", e2); + } + }); throw e; } @@ -328,8 +270,8 @@ public class OpenAiChatModel implements ChatLanguageModel, TokenCountEstimator { } public static OpenAiChatModelBuilder builder() { - for (OpenAiChatModelBuilderFactory factory : - loadFactories(OpenAiChatModelBuilderFactory.class)) { + for (OpenAiChatModelBuilderFactory factory : loadFactories( + OpenAiChatModelBuilderFactory.class)) { return factory.get(); } return new OpenAiChatModelBuilder(); diff --git a/common/src/main/java/dev/langchain4j/model/openai/OpenAiChatModelName.java b/common/src/main/java/dev/langchain4j/model/openai/OpenAiChatModelName.java index 56db8f911..06842f6da 100644 --- a/common/src/main/java/dev/langchain4j/model/openai/OpenAiChatModelName.java +++ b/common/src/main/java/dev/langchain4j/model/openai/OpenAiChatModelName.java @@ -3,9 +3,8 @@ package dev.langchain4j.model.openai; public enum OpenAiChatModelName { GPT_3_5_TURBO("gpt-3.5-turbo"), // alias @Deprecated - GPT_3_5_TURBO_0613("gpt-3.5-turbo-0613"), - GPT_3_5_TURBO_1106("gpt-3.5-turbo-1106"), - GPT_3_5_TURBO_0125("gpt-3.5-turbo-0125"), + GPT_3_5_TURBO_0613("gpt-3.5-turbo-0613"), GPT_3_5_TURBO_1106( + "gpt-3.5-turbo-1106"), GPT_3_5_TURBO_0125("gpt-3.5-turbo-0125"), GPT_3_5_TURBO_16K("gpt-3.5-turbo-16k"), // alias @Deprecated @@ -13,22 +12,18 @@ public enum OpenAiChatModelName { GPT_4("gpt-4"), // alias @Deprecated - GPT_4_0314("gpt-4-0314"), - GPT_4_0613("gpt-4-0613"), + GPT_4_0314("gpt-4-0314"), GPT_4_0613("gpt-4-0613"), GPT_4_TURBO_PREVIEW("gpt-4-turbo-preview"), // alias - GPT_4_1106_PREVIEW("gpt-4-1106-preview"), - GPT_4_0125_PREVIEW("gpt-4-0125-preview"), + GPT_4_1106_PREVIEW("gpt-4-1106-preview"), GPT_4_0125_PREVIEW("gpt-4-0125-preview"), GPT_4_32K("gpt-4-32k"), // alias - GPT_4_32K_0314("gpt-4-32k-0314"), - GPT_4_32K_0613("gpt-4-32k-0613"), + GPT_4_32K_0314("gpt-4-32k-0314"), GPT_4_32K_0613("gpt-4-32k-0613"), @Deprecated GPT_4_VISION_PREVIEW("gpt-4-vision-preview"), - GPT_4_O("gpt-4o"), - GPT_4_O_MINI("gpt-4o-mini"); + GPT_4_O("gpt-4o"), GPT_4_O_MINI("gpt-4o-mini"); private final String stringValue; diff --git a/common/src/main/java/dev/langchain4j/model/zhipu/ChatCompletionModel.java b/common/src/main/java/dev/langchain4j/model/zhipu/ChatCompletionModel.java index 6ea276723..4e153f9e5 100644 --- a/common/src/main/java/dev/langchain4j/model/zhipu/ChatCompletionModel.java +++ b/common/src/main/java/dev/langchain4j/model/zhipu/ChatCompletionModel.java @@ -1,9 +1,7 @@ package dev.langchain4j.model.zhipu; public enum ChatCompletionModel { - GLM_4("glm-4"), - GLM_3_TURBO("glm-3-turbo"), - CHATGLM_TURBO("chatglm_turbo"); + GLM_4("glm-4"), GLM_3_TURBO("glm-3-turbo"), CHATGLM_TURBO("chatglm_turbo"); private final String value; diff --git a/common/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiChatModel.java b/common/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiChatModel.java index 3c1436d33..3d14e4045 100644 --- a/common/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiChatModel.java +++ b/common/src/main/java/dev/langchain4j/model/zhipu/ZhipuAiChatModel.java @@ -27,8 +27,8 @@ import static java.util.Collections.singletonList; /** * Represents an ZhipuAi language model with a chat completion interface, such as glm-3-turbo and - * glm-4. You can find description of parameters here. + * glm-4. You can find description of parameters + * here. */ public class ZhipuAiChatModel implements ChatLanguageModel { @@ -41,15 +41,8 @@ public class ZhipuAiChatModel implements ChatLanguageModel { private final ZhipuAiClient client; @Builder - public ZhipuAiChatModel( - String baseUrl, - String apiKey, - Double temperature, - Double topP, - String model, - Integer maxRetries, - Integer maxToken, - Boolean logRequests, + public ZhipuAiChatModel(String baseUrl, String apiKey, Double temperature, Double topP, + String model, Integer maxRetries, Integer maxToken, Boolean logRequests, Boolean logResponses) { this.baseUrl = getOrDefault(baseUrl, "https://open.bigmodel.cn/"); this.temperature = getOrDefault(temperature, 0.7); @@ -57,18 +50,14 @@ public class ZhipuAiChatModel implements ChatLanguageModel { this.model = getOrDefault(model, ChatCompletionModel.GLM_4.toString()); this.maxRetries = getOrDefault(maxRetries, 3); this.maxToken = getOrDefault(maxToken, 512); - this.client = - ZhipuAiClient.builder() - .baseUrl(this.baseUrl) - .apiKey(apiKey) - .logRequests(getOrDefault(logRequests, false)) - .logResponses(getOrDefault(logResponses, false)) - .build(); + this.client = ZhipuAiClient.builder().baseUrl(this.baseUrl).apiKey(apiKey) + .logRequests(getOrDefault(logRequests, false)) + .logResponses(getOrDefault(logResponses, false)).build(); } public static ZhipuAiChatModelBuilder builder() { - for (ZhipuAiChatModelBuilderFactory factories : - loadFactories(ZhipuAiChatModelBuilderFactory.class)) { + for (ZhipuAiChatModelBuilderFactory factories : loadFactories( + ZhipuAiChatModelBuilderFactory.class)) { return factories.get(); } return new ZhipuAiChatModelBuilder(); @@ -80,15 +69,13 @@ public class ZhipuAiChatModel implements ChatLanguageModel { } @Override - public Response generate( - List messages, List toolSpecifications) { + public Response generate(List messages, + List toolSpecifications) { ensureNotEmpty(messages, "messages"); ChatCompletionRequest.Builder requestBuilder = ChatCompletionRequest.builder().model(this.model).maxTokens(maxToken).stream(false) - .topP(topP) - .toolChoice(AUTO) - .messages(toZhipuAiMessages(messages)); + .topP(topP).toolChoice(AUTO).messages(toZhipuAiMessages(messages)); if (!isNullOrEmpty(toolSpecifications)) { requestBuilder.tools(toTools(toolSpecifications)); @@ -96,17 +83,15 @@ public class ZhipuAiChatModel implements ChatLanguageModel { ChatCompletionResponse response = withRetry(() -> client.chatCompletion(requestBuilder.build()), maxRetries); - return Response.from( - aiMessageFrom(response), - tokenUsageFrom(response.getUsage()), + return Response.from(aiMessageFrom(response), tokenUsageFrom(response.getUsage()), finishReasonFrom(response.getChoices().get(0).getFinishReason())); } @Override - public Response generate( - List messages, ToolSpecification toolSpecification) { - return generate( - messages, toolSpecification != null ? singletonList(toolSpecification) : null); + public Response generate(List messages, + ToolSpecification toolSpecification) { + return generate(messages, + toolSpecification != null ? singletonList(toolSpecification) : null); } public static class ZhipuAiChatModelBuilder { diff --git a/common/src/main/java/dev/langchain4j/provider/AzureModelFactory.java b/common/src/main/java/dev/langchain4j/provider/AzureModelFactory.java index 5defb178b..1637c3ac9 100644 --- a/common/src/main/java/dev/langchain4j/provider/AzureModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/AzureModelFactory.java @@ -20,36 +20,27 @@ public class AzureModelFactory implements ModelFactory, InitializingBean { @Override public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { - AzureOpenAiChatModel.Builder builder = - AzureOpenAiChatModel.builder() - .endpoint(modelConfig.getBaseUrl()) - .apiKey(modelConfig.getApiKey()) - .deploymentName(modelConfig.getModelName()) - .temperature(modelConfig.getTemperature()) - .maxRetries(modelConfig.getMaxRetries()) - .topP(modelConfig.getTopP()) - .timeout( - Duration.ofSeconds( - modelConfig.getTimeOut() == null - ? 0L - : modelConfig.getTimeOut())) - .logRequestsAndResponses( - modelConfig.getLogRequests() != null - && modelConfig.getLogResponses()); + AzureOpenAiChatModel.Builder builder = AzureOpenAiChatModel.builder() + .endpoint(modelConfig.getBaseUrl()).apiKey(modelConfig.getApiKey()) + .deploymentName(modelConfig.getModelName()) + .temperature(modelConfig.getTemperature()).maxRetries(modelConfig.getMaxRetries()) + .topP(modelConfig.getTopP()) + .timeout(Duration.ofSeconds( + modelConfig.getTimeOut() == null ? 0L : modelConfig.getTimeOut())) + .logRequestsAndResponses( + modelConfig.getLogRequests() != null && modelConfig.getLogResponses()); return builder.build(); } @Override public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) { AzureOpenAiEmbeddingModel.Builder builder = - AzureOpenAiEmbeddingModel.builder() - .endpoint(embeddingModelConfig.getBaseUrl()) + AzureOpenAiEmbeddingModel.builder().endpoint(embeddingModelConfig.getBaseUrl()) .apiKey(embeddingModelConfig.getApiKey()) .deploymentName(embeddingModelConfig.getModelName()) .maxRetries(embeddingModelConfig.getMaxRetries()) - .logRequestsAndResponses( - embeddingModelConfig.getLogRequests() != null - && embeddingModelConfig.getLogResponses()); + .logRequestsAndResponses(embeddingModelConfig.getLogRequests() != null + && embeddingModelConfig.getLogResponses()); return builder.build(); } diff --git a/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java b/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java index e2447af14..22529ef73 100644 --- a/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/DashscopeModelFactory.java @@ -19,25 +19,17 @@ public class DashscopeModelFactory implements ModelFactory, InitializingBean { @Override public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { - return QwenChatModel.builder() - .baseUrl(modelConfig.getBaseUrl()) - .apiKey(modelConfig.getApiKey()) - .modelName(modelConfig.getModelName()) - .temperature( - modelConfig.getTemperature() == null - ? 0L - : modelConfig.getTemperature().floatValue()) - .topP(modelConfig.getTopP()) - .enableSearch(modelConfig.getEnableSearch()) - .build(); + return QwenChatModel.builder().baseUrl(modelConfig.getBaseUrl()) + .apiKey(modelConfig.getApiKey()).modelName(modelConfig.getModelName()) + .temperature(modelConfig.getTemperature() == null ? 0L + : modelConfig.getTemperature().floatValue()) + .topP(modelConfig.getTopP()).enableSearch(modelConfig.getEnableSearch()).build(); } @Override public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) { - return QwenEmbeddingModel.builder() - .apiKey(embeddingModelConfig.getApiKey()) - .modelName(embeddingModelConfig.getModelName()) - .build(); + return QwenEmbeddingModel.builder().apiKey(embeddingModelConfig.getApiKey()) + .modelName(embeddingModelConfig.getModelName()).build(); } @Override diff --git a/common/src/main/java/dev/langchain4j/provider/LocalAiModelFactory.java b/common/src/main/java/dev/langchain4j/provider/LocalAiModelFactory.java index 81c40b9b2..9170efd4e 100644 --- a/common/src/main/java/dev/langchain4j/provider/LocalAiModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/LocalAiModelFactory.java @@ -19,27 +19,20 @@ public class LocalAiModelFactory implements ModelFactory, InitializingBean { @Override public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { - return LocalAiChatModel.builder() - .baseUrl(modelConfig.getBaseUrl()) - .modelName(modelConfig.getModelName()) - .temperature(modelConfig.getTemperature()) - .timeout(Duration.ofSeconds(modelConfig.getTimeOut())) - .topP(modelConfig.getTopP()) + return LocalAiChatModel.builder().baseUrl(modelConfig.getBaseUrl()) + .modelName(modelConfig.getModelName()).temperature(modelConfig.getTemperature()) + .timeout(Duration.ofSeconds(modelConfig.getTimeOut())).topP(modelConfig.getTopP()) .logRequests(modelConfig.getLogRequests()) - .logResponses(modelConfig.getLogResponses()) - .maxRetries(modelConfig.getMaxRetries()) + .logResponses(modelConfig.getLogResponses()).maxRetries(modelConfig.getMaxRetries()) .build(); } @Override public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel) { - return LocalAiEmbeddingModel.builder() - .baseUrl(embeddingModel.getBaseUrl()) - .modelName(embeddingModel.getModelName()) - .maxRetries(embeddingModel.getMaxRetries()) + return LocalAiEmbeddingModel.builder().baseUrl(embeddingModel.getBaseUrl()) + .modelName(embeddingModel.getModelName()).maxRetries(embeddingModel.getMaxRetries()) .logRequests(embeddingModel.getLogRequests()) - .logResponses(embeddingModel.getLogResponses()) - .build(); + .logResponses(embeddingModel.getLogResponses()).build(); } @Override diff --git a/common/src/main/java/dev/langchain4j/provider/ModelProvider.java b/common/src/main/java/dev/langchain4j/provider/ModelProvider.java index fdce1497b..592bfe8d1 100644 --- a/common/src/main/java/dev/langchain4j/provider/ModelProvider.java +++ b/common/src/main/java/dev/langchain4j/provider/ModelProvider.java @@ -25,8 +25,7 @@ public class ModelProvider { } public static ChatLanguageModel getChatModel(ChatModelConfig modelConfig) { - if (modelConfig == null - || StringUtils.isBlank(modelConfig.getProvider()) + if (modelConfig == null || StringUtils.isBlank(modelConfig.getProvider()) || StringUtils.isBlank(modelConfig.getBaseUrl())) { ChatModelParameterConfig parameterConfig = ContextUtils.getBean(ChatModelParameterConfig.class); diff --git a/common/src/main/java/dev/langchain4j/provider/OllamaModelFactory.java b/common/src/main/java/dev/langchain4j/provider/OllamaModelFactory.java index c34b4882e..77b54dddf 100644 --- a/common/src/main/java/dev/langchain4j/provider/OllamaModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/OllamaModelFactory.java @@ -21,27 +21,20 @@ public class OllamaModelFactory implements ModelFactory, InitializingBean { @Override public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { - return OllamaChatModel.builder() - .baseUrl(modelConfig.getBaseUrl()) - .modelName(modelConfig.getModelName()) - .temperature(modelConfig.getTemperature()) - .timeout(Duration.ofSeconds(modelConfig.getTimeOut())) - .topP(modelConfig.getTopP()) - .maxRetries(modelConfig.getMaxRetries()) - .logRequests(modelConfig.getLogRequests()) - .logResponses(modelConfig.getLogResponses()) - .build(); + return OllamaChatModel.builder().baseUrl(modelConfig.getBaseUrl()) + .modelName(modelConfig.getModelName()).temperature(modelConfig.getTemperature()) + .timeout(Duration.ofSeconds(modelConfig.getTimeOut())).topP(modelConfig.getTopP()) + .maxRetries(modelConfig.getMaxRetries()).logRequests(modelConfig.getLogRequests()) + .logResponses(modelConfig.getLogResponses()).build(); } @Override public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) { - return OllamaEmbeddingModel.builder() - .baseUrl(embeddingModelConfig.getBaseUrl()) + return OllamaEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl()) .modelName(embeddingModelConfig.getModelName()) .maxRetries(embeddingModelConfig.getMaxRetries()) .logRequests(embeddingModelConfig.getLogRequests()) - .logResponses(embeddingModelConfig.getLogResponses()) - .build(); + .logResponses(embeddingModelConfig.getLogResponses()).build(); } @Override diff --git a/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java b/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java index 02815cc5c..2bd90eb55 100644 --- a/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/OpenAiModelFactory.java @@ -21,29 +21,22 @@ public class OpenAiModelFactory implements ModelFactory, InitializingBean { @Override public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { - return OpenAiChatModel.builder() - .baseUrl(modelConfig.getBaseUrl()) - .modelName(modelConfig.getModelName()) - .apiKey(modelConfig.keyDecrypt()) - .temperature(modelConfig.getTemperature()) - .topP(modelConfig.getTopP()) + return OpenAiChatModel.builder().baseUrl(modelConfig.getBaseUrl()) + .modelName(modelConfig.getModelName()).apiKey(modelConfig.keyDecrypt()) + .temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP()) .maxRetries(modelConfig.getMaxRetries()) .timeout(Duration.ofSeconds(modelConfig.getTimeOut())) .logRequests(modelConfig.getLogRequests()) - .logResponses(modelConfig.getLogResponses()) - .build(); + .logResponses(modelConfig.getLogResponses()).build(); } @Override public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModel) { - return OpenAiEmbeddingModel.builder() - .baseUrl(embeddingModel.getBaseUrl()) - .apiKey(embeddingModel.getApiKey()) - .modelName(embeddingModel.getModelName()) + return OpenAiEmbeddingModel.builder().baseUrl(embeddingModel.getBaseUrl()) + .apiKey(embeddingModel.getApiKey()).modelName(embeddingModel.getModelName()) .maxRetries(embeddingModel.getMaxRetries()) .logRequests(embeddingModel.getLogRequests()) - .logResponses(embeddingModel.getLogResponses()) - .build(); + .logResponses(embeddingModel.getLogResponses()).build(); } @Override diff --git a/common/src/main/java/dev/langchain4j/provider/QianfanModelFactory.java b/common/src/main/java/dev/langchain4j/provider/QianfanModelFactory.java index 10db2ee1c..28f8f3d3e 100644 --- a/common/src/main/java/dev/langchain4j/provider/QianfanModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/QianfanModelFactory.java @@ -21,31 +21,23 @@ public class QianfanModelFactory implements ModelFactory, InitializingBean { @Override public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { - return QianfanChatModel.builder() - .baseUrl(modelConfig.getBaseUrl()) - .apiKey(modelConfig.getApiKey()) - .secretKey(modelConfig.getSecretKey()) - .endpoint(modelConfig.getEndpoint()) - .modelName(modelConfig.getModelName()) - .temperature(modelConfig.getTemperature()) - .topP(modelConfig.getTopP()) - .maxRetries(modelConfig.getMaxRetries()) - .logRequests(modelConfig.getLogRequests()) - .logResponses(modelConfig.getLogResponses()) - .build(); + return QianfanChatModel.builder().baseUrl(modelConfig.getBaseUrl()) + .apiKey(modelConfig.getApiKey()).secretKey(modelConfig.getSecretKey()) + .endpoint(modelConfig.getEndpoint()).modelName(modelConfig.getModelName()) + .temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP()) + .maxRetries(modelConfig.getMaxRetries()).logRequests(modelConfig.getLogRequests()) + .logResponses(modelConfig.getLogResponses()).build(); } @Override public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) { - return QianfanEmbeddingModel.builder() - .baseUrl(embeddingModelConfig.getBaseUrl()) + return QianfanEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl()) .apiKey(embeddingModelConfig.getApiKey()) .secretKey(embeddingModelConfig.getSecretKey()) .modelName(embeddingModelConfig.getModelName()) .maxRetries(embeddingModelConfig.getMaxRetries()) .logRequests(embeddingModelConfig.getLogRequests()) - .logResponses(embeddingModelConfig.getLogResponses()) - .build(); + .logResponses(embeddingModelConfig.getLogResponses()).build(); } @Override diff --git a/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java b/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java index 0f004d424..89bd8c7c1 100644 --- a/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java +++ b/common/src/main/java/dev/langchain4j/provider/ZhipuModelFactory.java @@ -19,28 +19,20 @@ public class ZhipuModelFactory implements ModelFactory, InitializingBean { @Override public ChatLanguageModel createChatModel(ChatModelConfig modelConfig) { - return ZhipuAiChatModel.builder() - .baseUrl(modelConfig.getBaseUrl()) - .apiKey(modelConfig.getApiKey()) - .model(modelConfig.getModelName()) - .temperature(modelConfig.getTemperature()) - .topP(modelConfig.getTopP()) - .maxRetries(modelConfig.getMaxRetries()) - .logRequests(modelConfig.getLogRequests()) - .logResponses(modelConfig.getLogResponses()) - .build(); + return ZhipuAiChatModel.builder().baseUrl(modelConfig.getBaseUrl()) + .apiKey(modelConfig.getApiKey()).model(modelConfig.getModelName()) + .temperature(modelConfig.getTemperature()).topP(modelConfig.getTopP()) + .maxRetries(modelConfig.getMaxRetries()).logRequests(modelConfig.getLogRequests()) + .logResponses(modelConfig.getLogResponses()).build(); } @Override public EmbeddingModel createEmbeddingModel(EmbeddingModelConfig embeddingModelConfig) { - return ZhipuAiEmbeddingModel.builder() - .baseUrl(embeddingModelConfig.getBaseUrl()) - .apiKey(embeddingModelConfig.getApiKey()) - .model(embeddingModelConfig.getModelName()) + return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelConfig.getBaseUrl()) + .apiKey(embeddingModelConfig.getApiKey()).model(embeddingModelConfig.getModelName()) .maxRetries(embeddingModelConfig.getMaxRetries()) .logRequests(embeddingModelConfig.getLogRequests()) - .logResponses(embeddingModelConfig.getLogResponses()) - .build(); + .logResponses(embeddingModelConfig.getLogResponses()).build(); } @Override diff --git a/common/src/main/java/dev/langchain4j/qianfan/spring/Properties.java b/common/src/main/java/dev/langchain4j/qianfan/spring/Properties.java index dfdb8b41c..7f809283e 100644 --- a/common/src/main/java/dev/langchain4j/qianfan/spring/Properties.java +++ b/common/src/main/java/dev/langchain4j/qianfan/spring/Properties.java @@ -12,13 +12,18 @@ public class Properties { static final String PREFIX = "langchain4j.qianfan"; - @NestedConfigurationProperty ChatModelProperties chatModel; + @NestedConfigurationProperty + ChatModelProperties chatModel; - @NestedConfigurationProperty ChatModelProperties streamingChatModel; + @NestedConfigurationProperty + ChatModelProperties streamingChatModel; - @NestedConfigurationProperty LanguageModelProperties languageModel; + @NestedConfigurationProperty + LanguageModelProperties languageModel; - @NestedConfigurationProperty LanguageModelProperties streamingLanguageModel; + @NestedConfigurationProperty + LanguageModelProperties streamingLanguageModel; - @NestedConfigurationProperty EmbeddingModelProperties embeddingModel; + @NestedConfigurationProperty + EmbeddingModelProperties embeddingModel; } diff --git a/common/src/main/java/dev/langchain4j/qianfan/spring/QianfanAutoConfig.java b/common/src/main/java/dev/langchain4j/qianfan/spring/QianfanAutoConfig.java index 7bef69e52..e324d54a5 100644 --- a/common/src/main/java/dev/langchain4j/qianfan/spring/QianfanAutoConfig.java +++ b/common/src/main/java/dev/langchain4j/qianfan/spring/QianfanAutoConfig.java @@ -20,8 +20,7 @@ public class QianfanAutoConfig { @ConditionalOnProperty(PREFIX + ".chat-model.api-key") QianfanChatModel qianfanChatModel(Properties properties) { ChatModelProperties chatModelProperties = properties.getChatModel(); - return QianfanChatModel.builder() - .baseUrl(chatModelProperties.getBaseUrl()) + return QianfanChatModel.builder().baseUrl(chatModelProperties.getBaseUrl()) .apiKey(chatModelProperties.getApiKey()) .secretKey(chatModelProperties.getSecretKey()) .endpoint(chatModelProperties.getEndpoint()) @@ -32,38 +31,32 @@ public class QianfanAutoConfig { .responseFormat(chatModelProperties.getResponseFormat()) .maxRetries(chatModelProperties.getMaxRetries()) .logRequests(chatModelProperties.getLogRequests()) - .logResponses(chatModelProperties.getLogResponses()) - .build(); + .logResponses(chatModelProperties.getLogResponses()).build(); } @Bean @ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key") QianfanStreamingChatModel qianfanStreamingChatModel(Properties properties) { ChatModelProperties chatModelProperties = properties.getStreamingChatModel(); - return QianfanStreamingChatModel.builder() - .endpoint(chatModelProperties.getEndpoint()) + return QianfanStreamingChatModel.builder().endpoint(chatModelProperties.getEndpoint()) .penaltyScore(chatModelProperties.getPenaltyScore()) .temperature(chatModelProperties.getTemperature()) - .topP(chatModelProperties.getTopP()) - .baseUrl(chatModelProperties.getBaseUrl()) + .topP(chatModelProperties.getTopP()).baseUrl(chatModelProperties.getBaseUrl()) .apiKey(chatModelProperties.getApiKey()) .secretKey(chatModelProperties.getSecretKey()) .modelName(chatModelProperties.getModelName()) .responseFormat(chatModelProperties.getResponseFormat()) .logRequests(chatModelProperties.getLogRequests()) - .logResponses(chatModelProperties.getLogResponses()) - .build(); + .logResponses(chatModelProperties.getLogResponses()).build(); } @Bean @ConditionalOnProperty(PREFIX + ".language-model.api-key") QianfanLanguageModel qianfanLanguageModel(Properties properties) { LanguageModelProperties languageModelProperties = properties.getLanguageModel(); - return QianfanLanguageModel.builder() - .endpoint(languageModelProperties.getEndpoint()) + return QianfanLanguageModel.builder().endpoint(languageModelProperties.getEndpoint()) .penaltyScore(languageModelProperties.getPenaltyScore()) - .topK(languageModelProperties.getTopK()) - .topP(languageModelProperties.getTopP()) + .topK(languageModelProperties.getTopK()).topP(languageModelProperties.getTopP()) .baseUrl(languageModelProperties.getBaseUrl()) .apiKey(languageModelProperties.getApiKey()) .secretKey(languageModelProperties.getSecretKey()) @@ -71,8 +64,7 @@ public class QianfanAutoConfig { .temperature(languageModelProperties.getTemperature()) .maxRetries(languageModelProperties.getMaxRetries()) .logRequests(languageModelProperties.getLogRequests()) - .logResponses(languageModelProperties.getLogResponses()) - .build(); + .logResponses(languageModelProperties.getLogResponses()).build(); } @Bean @@ -82,8 +74,7 @@ public class QianfanAutoConfig { return QianfanStreamingLanguageModel.builder() .endpoint(languageModelProperties.getEndpoint()) .penaltyScore(languageModelProperties.getPenaltyScore()) - .topK(languageModelProperties.getTopK()) - .topP(languageModelProperties.getTopP()) + .topK(languageModelProperties.getTopK()).topP(languageModelProperties.getTopP()) .baseUrl(languageModelProperties.getBaseUrl()) .apiKey(languageModelProperties.getApiKey()) .secretKey(languageModelProperties.getSecretKey()) @@ -91,16 +82,14 @@ public class QianfanAutoConfig { .temperature(languageModelProperties.getTemperature()) .maxRetries(languageModelProperties.getMaxRetries()) .logRequests(languageModelProperties.getLogRequests()) - .logResponses(languageModelProperties.getLogResponses()) - .build(); + .logResponses(languageModelProperties.getLogResponses()).build(); } @Bean @ConditionalOnProperty(PREFIX + ".embedding-model.api-key") QianfanEmbeddingModel qianfanEmbeddingModel(Properties properties) { EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel(); - return QianfanEmbeddingModel.builder() - .baseUrl(embeddingModelProperties.getBaseUrl()) + return QianfanEmbeddingModel.builder().baseUrl(embeddingModelProperties.getBaseUrl()) .endpoint(embeddingModelProperties.getEndpoint()) .apiKey(embeddingModelProperties.getApiKey()) .secretKey(embeddingModelProperties.getSecretKey()) @@ -108,7 +97,6 @@ public class QianfanAutoConfig { .user(embeddingModelProperties.getUser()) .maxRetries(embeddingModelProperties.getMaxRetries()) .logRequests(embeddingModelProperties.getLogRequests()) - .logResponses(embeddingModelProperties.getLogResponses()) - .build(); + .logResponses(embeddingModelProperties.getLogResponses()).build(); } } diff --git a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java index 234449d0b..cf86025ec 100644 --- a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java +++ b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreFactoryProvider.java @@ -27,24 +27,19 @@ public class EmbeddingStoreFactoryProvider { return ContextUtils.getBean(EmbeddingStoreFactory.class); } if (EmbeddingStoreType.CHROMA.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) { - return factoryMap.computeIfAbsent( - embeddingStoreConfig, + return factoryMap.computeIfAbsent(embeddingStoreConfig, storeConfig -> new ChromaEmbeddingStoreFactory(storeConfig)); } if (EmbeddingStoreType.MILVUS.name().equalsIgnoreCase(embeddingStoreConfig.getProvider())) { - return factoryMap.computeIfAbsent( - embeddingStoreConfig, + return factoryMap.computeIfAbsent(embeddingStoreConfig, storeConfig -> new MilvusEmbeddingStoreFactory(storeConfig)); } - if (EmbeddingStoreType.IN_MEMORY - .name() + if (EmbeddingStoreType.IN_MEMORY.name() .equalsIgnoreCase(embeddingStoreConfig.getProvider())) { - return factoryMap.computeIfAbsent( - embeddingStoreConfig, + return factoryMap.computeIfAbsent(embeddingStoreConfig, storeConfig -> new InMemoryEmbeddingStoreFactory(storeConfig)); } - throw new RuntimeException( - "Unsupported EmbeddingStoreFactory provider: " - + embeddingStoreConfig.getProvider()); + throw new RuntimeException("Unsupported EmbeddingStoreFactory provider: " + + embeddingStoreConfig.getProvider()); } } diff --git a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreType.java b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreType.java index 7e2e0e3b3..068ac0ada 100644 --- a/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreType.java +++ b/common/src/main/java/dev/langchain4j/store/embedding/EmbeddingStoreType.java @@ -1,7 +1,5 @@ package dev.langchain4j.store.embedding; public enum EmbeddingStoreType { - IN_MEMORY, - MILVUS, - CHROMA + IN_MEMORY, MILVUS, CHROMA } diff --git a/common/src/main/java/dev/langchain4j/store/embedding/Retrieval.java b/common/src/main/java/dev/langchain4j/store/embedding/Retrieval.java index f8dc386c0..cfe24a643 100644 --- a/common/src/main/java/dev/langchain4j/store/embedding/Retrieval.java +++ b/common/src/main/java/dev/langchain4j/store/embedding/Retrieval.java @@ -36,8 +36,7 @@ public class Retrieval { } Retrieval retrieval = (Retrieval) o; return Double.compare(retrieval.similarity, similarity) == 0 - && Objects.equal(id, retrieval.id) - && Objects.equal(query, retrieval.query) + && Objects.equal(id, retrieval.id) && Objects.equal(query, retrieval.query) && Objects.equal(metadata, retrieval.metadata); } diff --git a/common/src/main/java/dev/langchain4j/store/embedding/TextSegmentConvert.java b/common/src/main/java/dev/langchain4j/store/embedding/TextSegmentConvert.java index 938f99fe9..2278eefd7 100644 --- a/common/src/main/java/dev/langchain4j/store/embedding/TextSegmentConvert.java +++ b/common/src/main/java/dev/langchain4j/store/embedding/TextSegmentConvert.java @@ -17,20 +17,12 @@ public class TextSegmentConvert { public static final String QUERY_ID = "queryId"; public static List convertToEmbedding(List dataItems) { - return dataItems.stream() - .map( - dataItem -> { - Map meta = - JSONObject.parseObject( - JSONObject.toJSONString(dataItem), Map.class); - TextSegment textSegment = - TextSegment.from(dataItem.getName(), new Metadata(meta)); - addQueryId( - textSegment, - dataItem.getId() + dataItem.getType().name().toLowerCase()); - return textSegment; - }) - .collect(Collectors.toList()); + return dataItems.stream().map(dataItem -> { + Map meta = JSONObject.parseObject(JSONObject.toJSONString(dataItem), Map.class); + TextSegment textSegment = TextSegment.from(dataItem.getName(), new Metadata(meta)); + addQueryId(textSegment, dataItem.getId() + dataItem.getType().name().toLowerCase()); + return textSegment; + }).collect(Collectors.toList()); } public static void addQueryId(TextSegment textSegment, String queryId) { diff --git a/common/src/main/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStore.java b/common/src/main/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStore.java index 237042c9e..5e3c7dd72 100644 --- a/common/src/main/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStore.java +++ b/common/src/main/java/dev/langchain4j/store/embedding/inmemory/InMemoryEmbeddingStore.java @@ -40,16 +40,19 @@ import static java.util.stream.Collectors.toList; /** * An {@link EmbeddingStore} that stores embeddings in memory. * - *

Uses a brute force approach by iterating over all embeddings to find the best matches. + *

+ * Uses a brute force approach by iterating over all embeddings to find the best matches. * - *

This store can be persisted using the {@link #serializeToJson()} and {@link - * #serializeToFile(Path)} methods. + *

+ * This store can be persisted using the {@link #serializeToJson()} and + * {@link #serializeToFile(Path)} methods. * - *

It can also be recreated from JSON or a file using the {@link #fromJson(String)} and {@link - * #fromFile(Path)} methods. + *

+ * It can also be recreated from JSON or a file using the {@link #fromJson(String)} and + * {@link #fromFile(Path)} methods. * - * @param The class of the object that has been embedded. Typically, it is {@link - * dev.langchain4j.data.segment.TextSegment}. + * @param The class of the object that has been embedded. Typically, it is + * {@link dev.langchain4j.data.segment.TextSegment}. */ public class InMemoryEmbeddingStore implements EmbeddingStore { @@ -88,10 +91,8 @@ public class InMemoryEmbeddingStore implements EmbeddingStore addAll(List embeddings) { - List> newEntries = - embeddings.stream() - .map(embedding -> new Entry(randomUUID(), embedding)) - .collect(toList()); + List> newEntries = embeddings.stream() + .map(embedding -> new Entry(randomUUID(), embedding)).collect(toList()); return add(newEntries); } @@ -103,11 +104,9 @@ public class InMemoryEmbeddingStore implements EmbeddingStore> newEntries = - IntStream.range(0, embeddings.size()) - .mapToObj( - i -> new Entry<>(randomUUID(), embeddings.get(i), embedded.get(i))) - .collect(toList()); + List> newEntries = IntStream.range(0, embeddings.size()) + .mapToObj(i -> new Entry<>(randomUUID(), embeddings.get(i), embedded.get(i))) + .collect(toList()); return add(newEntries); } @@ -123,16 +122,15 @@ public class InMemoryEmbeddingStore implements EmbeddingStore { - if (entry.embedded instanceof TextSegment) { - return filter.test(((TextSegment) entry.embedded).metadata()); - } else if (entry.embedded == null) { - return false; - } else { - throw new UnsupportedOperationException("Not supported yet."); - } - }); + entries.removeIf(entry -> { + if (entry.embedded instanceof TextSegment) { + return filter.test(((TextSegment) entry.embedded).metadata()); + } else if (entry.embedded == null) { + return false; + } else { + throw new UnsupportedOperationException("Not supported yet."); + } + }); } @Override @@ -157,9 +155,8 @@ public class InMemoryEmbeddingStore implements EmbeddingStore= embeddingSearchRequest.minScore()) { matches.add(new EmbeddingMatch<>(score, entry.id, entry.embedding, entry.embedded)); @@ -247,8 +244,8 @@ public class InMemoryEmbeddingStore implements EmbeddingStore { private final boolean retrieveEmbeddingsOnSearch; private final boolean autoFlushOnInsert; - public MilvusEmbeddingStore( - String host, - Integer port, - String collectionName, - Integer dimension, - IndexType indexType, - MetricType metricType, - String uri, - String token, - String username, - String password, - ConsistencyLevelEnum consistencyLevel, - Boolean retrieveEmbeddingsOnSearch, - Boolean autoFlushOnInsert, - String databaseName) { + public MilvusEmbeddingStore(String host, Integer port, String collectionName, Integer dimension, + IndexType indexType, MetricType metricType, String uri, String token, String username, + String password, ConsistencyLevelEnum consistencyLevel, + Boolean retrieveEmbeddingsOnSearch, Boolean autoFlushOnInsert, String databaseName) { ConnectParam.Builder connectBuilder = - ConnectParam.newBuilder() - .withHost(getOrDefault(host, "localhost")) - .withPort(getOrDefault(port, 19530)) - .withUri(uri) - .withToken(token) + ConnectParam.newBuilder().withHost(getOrDefault(host, "localhost")) + .withPort(getOrDefault(port, 19530)).withUri(uri).withToken(token) .withAuthorization(getOrDefault(username, ""), getOrDefault(password, "")); if (databaseName != null) { @@ -93,12 +79,9 @@ public class MilvusEmbeddingStore implements EmbeddingStore { this.autoFlushOnInsert = getOrDefault(autoFlushOnInsert, false); if (!hasCollection(this.milvusClient, this.collectionName)) { - createCollection( - this.milvusClient, this.collectionName, ensureNotNull(dimension, "dimension")); - createIndex( - this.milvusClient, - this.collectionName, - getOrDefault(indexType, FLAT), + createCollection(this.milvusClient, this.collectionName, + ensureNotNull(dimension, "dimension")); + createIndex(this.milvusClient, this.collectionName, getOrDefault(indexType, FLAT), this.metricType); } @@ -145,49 +128,36 @@ public class MilvusEmbeddingStore implements EmbeddingStore { public EmbeddingSearchResult search( EmbeddingSearchRequest embeddingSearchRequest) { - SearchParam searchParam = - buildSearchRequest( - collectionName, - embeddingSearchRequest.queryEmbedding().vectorAsList(), - embeddingSearchRequest.filter(), - embeddingSearchRequest.maxResults(), - metricType, - consistencyLevel); + SearchParam searchParam = buildSearchRequest(collectionName, + embeddingSearchRequest.queryEmbedding().vectorAsList(), + embeddingSearchRequest.filter(), embeddingSearchRequest.maxResults(), metricType, + consistencyLevel); SearchResultsWrapper resultsWrapper = CollectionOperationsExecutor.search(milvusClient, searchParam); - List> matches = - toEmbeddingMatches( - milvusClient, - resultsWrapper, - collectionName, - consistencyLevel, - retrieveEmbeddingsOnSearch); + List> matches = toEmbeddingMatches(milvusClient, resultsWrapper, + collectionName, consistencyLevel, retrieveEmbeddingsOnSearch); List> result = - matches.stream() - .filter(match -> match.score() >= embeddingSearchRequest.minScore()) + matches.stream().filter(match -> match.score() >= embeddingSearchRequest.minScore()) .collect(toList()); return new EmbeddingSearchResult<>(result); } private void addInternal(String id, Embedding embedding, TextSegment textSegment) { - addAllInternal( - singletonList(id), - singletonList(embedding), + addAllInternal(singletonList(id), singletonList(embedding), textSegment == null ? null : singletonList(textSegment)); } - private void addAllInternal( - List ids, List embeddings, List textSegments) { + private void addAllInternal(List ids, List embeddings, + List textSegments) { List fields = new ArrayList<>(); fields.add(new InsertParam.Field(ID_FIELD_NAME, ids)); fields.add(new InsertParam.Field(TEXT_FIELD_NAME, toScalars(textSegments, ids.size()))); - fields.add( - new InsertParam.Field( - METADATA_FIELD_NAME, toMetadataJsons(textSegments, ids.size()))); + fields.add(new InsertParam.Field(METADATA_FIELD_NAME, + toMetadataJsons(textSegments, ids.size()))); fields.add(new InsertParam.Field(VECTOR_FIELD_NAME, toVectors(embeddings))); insert(this.milvusClient, this.collectionName, fields); @@ -199,22 +169,22 @@ public class MilvusEmbeddingStore implements EmbeddingStore { /** * Removes a single embedding from the store by ID. * - *

CAUTION + *

+ * CAUTION * *

    - *
  • Deleted entities can still be retrieved immediately after the deletion if the - * consistency level is set lower than {@code Strong} - *
  • Entities deleted beyond the pre-specified span of time for Time Travel cannot be - * retrieved again. - *
  • Frequent deletion operations will impact the system performance. - *
  • Before deleting entities by comlpex boolean expressions, make sure the collection has - * been loaded. - *
  • Deleting entities by complex boolean expressions is not an atomic operation. Therefore, - * if it fails halfway through, some data may still be deleted. - *
  • Deleting entities by complex boolean expressions is supported only when the consistency - * is set to Bounded. For details, see - * Consistency + *
  • Deleted entities can still be retrieved immediately after the deletion if the consistency + * level is set lower than {@code Strong} + *
  • Entities deleted beyond the pre-specified span of time for Time Travel cannot be + * retrieved again. + *
  • Frequent deletion operations will impact the system performance. + *
  • Before deleting entities by comlpex boolean expressions, make sure the collection has + * been loaded. + *
  • Deleting entities by complex boolean expressions is not an atomic operation. Therefore, + * if it fails halfway through, some data may still be deleted. + *
  • Deleting entities by complex boolean expressions is supported only when the consistency + * is set to Bounded. For details, + * see Consistency *
* * @param ids A collection of unique IDs of the embeddings to be removed. @@ -223,36 +193,34 @@ public class MilvusEmbeddingStore implements EmbeddingStore { @Override public void removeAll(Collection ids) { ensureNotEmpty(ids, "ids"); - removeForVector( - this.milvusClient, - this.collectionName, + removeForVector(this.milvusClient, this.collectionName, format("%s in %s", ID_FIELD_NAME, formatValues(ids))); } /** * Removes all embeddings that match the specified {@link Filter} from the store. * - *

CAUTION + *

+ * CAUTION * *

    - *
  • Deleted entities can still be retrieved immediately after the deletion if the - * consistency level is set lower than {@code Strong} - *
  • Entities deleted beyond the pre-specified span of time for Time Travel cannot be - * retrieved again. - *
  • Frequent deletion operations will impact the system performance. - *
  • Before deleting entities by comlpex boolean expressions, make sure the collection has - * been loaded. - *
  • Deleting entities by complex boolean expressions is not an atomic operation. Therefore, - * if it fails halfway through, some data may still be deleted. - *
  • Deleting entities by complex boolean expressions is supported only when the consistency - * is set to Bounded. For details, see - * Consistency + *
  • Deleted entities can still be retrieved immediately after the deletion if the consistency + * level is set lower than {@code Strong} + *
  • Entities deleted beyond the pre-specified span of time for Time Travel cannot be + * retrieved again. + *
  • Frequent deletion operations will impact the system performance. + *
  • Before deleting entities by comlpex boolean expressions, make sure the collection has + * been loaded. + *
  • Deleting entities by complex boolean expressions is not an atomic operation. Therefore, + * if it fails halfway through, some data may still be deleted. + *
  • Deleting entities by complex boolean expressions is supported only when the consistency + * is set to Bounded. For details, + * see Consistency *
* * @param filter The filter to be applied to the {@link Metadata} of the {@link TextSegment} - * during removal. Only embeddings whose {@code TextSegment}'s {@code Metadata} match the - * {@code Filter} will be removed. + * during removal. Only embeddings whose {@code TextSegment}'s {@code Metadata} match the + * {@code Filter} will be removed. * @since Milvus version 2.3.x */ @Override @@ -264,30 +232,30 @@ public class MilvusEmbeddingStore implements EmbeddingStore { /** * Removes all embeddings from the store. * - *

CAUTION + *

+ * CAUTION * *

    - *
  • Deleted entities can still be retrieved immediately after the deletion if the - * consistency level is set lower than {@code Strong} - *
  • Entities deleted beyond the pre-specified span of time for Time Travel cannot be - * retrieved again. - *
  • Frequent deletion operations will impact the system performance. - *
  • Before deleting entities by comlpex boolean expressions, make sure the collection has - * been loaded. - *
  • Deleting entities by complex boolean expressions is not an atomic operation. Therefore, - * if it fails halfway through, some data may still be deleted. - *
  • Deleting entities by complex boolean expressions is supported only when the consistency - * is set to Bounded. For details, see - * Consistency + *
  • Deleted entities can still be retrieved immediately after the deletion if the consistency + * level is set lower than {@code Strong} + *
  • Entities deleted beyond the pre-specified span of time for Time Travel cannot be + * retrieved again. + *
  • Frequent deletion operations will impact the system performance. + *
  • Before deleting entities by comlpex boolean expressions, make sure the collection has + * been loaded. + *
  • Deleting entities by complex boolean expressions is not an atomic operation. Therefore, + * if it fails halfway through, some data may still be deleted. + *
  • Deleting entities by complex boolean expressions is supported only when the consistency + * is set to Bounded. For details, + * see Consistency *
* * @since Milvus version 2.3.x */ @Override public void removeAll() { - removeForVector( - this.milvusClient, this.collectionName, format("%s != \"\"", ID_FIELD_NAME)); + removeForVector(this.milvusClient, this.collectionName, + format("%s != \"\"", ID_FIELD_NAME)); } public static class Builder { @@ -327,7 +295,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore { /** * @param collectionName The name of the Milvus collection. If there is no such collection - * yet, it will be created automatically. Default value: "default". + * yet, it will be created automatically. Default value: "default". * @return builder */ public Builder collectionName(String collectionName) { @@ -337,7 +305,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore { /** * @param dimension The dimension of the embedding vector. (e.g. 384) Mandatory if a new - * collection should be created. + * collection should be created. * @return builder */ public Builder dimension(Integer dimension) { @@ -356,7 +324,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore { /** * @param metricType The type of the metric used for similarity search. Default value: - * COSINE. + * COSINE. * @return builder */ public Builder metricType(MetricType metricType) { @@ -366,7 +334,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore { /** * @param uri The URI of the managed Milvus instance. (e.g. - * "https://xxx.api.gcp-us-west1.zillizcloud.com") + * "https://xxx.api.gcp-us-west1.zillizcloud.com") * @return builder */ public Builder uri(String uri) { @@ -384,8 +352,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore { } /** - * @param username The username. See details here. + * @param username The username. See details + * here. * @return builder */ public Builder username(String username) { @@ -394,8 +362,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore { } /** - * @param password The password. See details here. + * @param password The password. See details + * here. * @return builder */ public Builder password(String password) { @@ -414,10 +382,10 @@ public class MilvusEmbeddingStore implements EmbeddingStore { /** * @param retrieveEmbeddingsOnSearch During a similarity search in Milvus (when calling - * findRelevant()), the embedding itself is not retrieved. To retrieve the embedding, an - * additional query is required. Setting this parameter to "true" will ensure that - * embedding is retrieved. Be aware that this will impact the performance of the search. - * Default value: false. + * findRelevant()), the embedding itself is not retrieved. To retrieve the embedding, + * an additional query is required. Setting this parameter to "true" will ensure that + * embedding is retrieved. Be aware that this will impact the performance of the + * search. Default value: false. * @return builder */ public Builder retrieveEmbeddingsOnSearch(Boolean retrieveEmbeddingsOnSearch) { @@ -428,8 +396,8 @@ public class MilvusEmbeddingStore implements EmbeddingStore { /** * @param autoFlushOnInsert Whether to automatically flush after each insert ({@code * add(...)} or {@code addAll(...)} methods). Default value: false. More info can be - * found here. + * found here. * @return builder */ public Builder autoFlushOnInsert(Boolean autoFlushOnInsert) { @@ -439,7 +407,7 @@ public class MilvusEmbeddingStore implements EmbeddingStore { /** * @param databaseName Milvus name of database. Default value: null. In this case default - * Milvus database name will be used. + * Milvus database name will be used. * @return builder */ public Builder databaseName(String databaseName) { @@ -448,21 +416,9 @@ public class MilvusEmbeddingStore implements EmbeddingStore { } public MilvusEmbeddingStore build() { - return new MilvusEmbeddingStore( - host, - port, - collectionName, - dimension, - indexType, - metricType, - uri, - token, - username, - password, - consistencyLevel, - retrieveEmbeddingsOnSearch, - autoFlushOnInsert, - databaseName); + return new MilvusEmbeddingStore(host, port, collectionName, dimension, indexType, + metricType, uri, token, username, password, consistencyLevel, + retrieveEmbeddingsOnSearch, autoFlushOnInsert, databaseName); } } } diff --git a/common/src/main/java/dev/langchain4j/zhipu/spring/Properties.java b/common/src/main/java/dev/langchain4j/zhipu/spring/Properties.java index 68bd040c2..f045d23fd 100644 --- a/common/src/main/java/dev/langchain4j/zhipu/spring/Properties.java +++ b/common/src/main/java/dev/langchain4j/zhipu/spring/Properties.java @@ -12,9 +12,12 @@ public class Properties { static final String PREFIX = "langchain4j.zhipu"; - @NestedConfigurationProperty ChatModelProperties chatModel; + @NestedConfigurationProperty + ChatModelProperties chatModel; - @NestedConfigurationProperty ChatModelProperties streamingChatModel; + @NestedConfigurationProperty + ChatModelProperties streamingChatModel; - @NestedConfigurationProperty EmbeddingModelProperties embeddingModel; + @NestedConfigurationProperty + EmbeddingModelProperties embeddingModel; } diff --git a/common/src/main/java/dev/langchain4j/zhipu/spring/ZhipuAutoConfig.java b/common/src/main/java/dev/langchain4j/zhipu/spring/ZhipuAutoConfig.java index 9b3ac945f..3eebaebf7 100644 --- a/common/src/main/java/dev/langchain4j/zhipu/spring/ZhipuAutoConfig.java +++ b/common/src/main/java/dev/langchain4j/zhipu/spring/ZhipuAutoConfig.java @@ -18,46 +18,36 @@ public class ZhipuAutoConfig { @ConditionalOnProperty(PREFIX + ".chat-model.api-key") ZhipuAiChatModel zhipuAiChatModel(Properties properties) { ChatModelProperties chatModelProperties = properties.getChatModel(); - return ZhipuAiChatModel.builder() - .baseUrl(chatModelProperties.getBaseUrl()) - .apiKey(chatModelProperties.getApiKey()) - .model(chatModelProperties.getModelName()) + return ZhipuAiChatModel.builder().baseUrl(chatModelProperties.getBaseUrl()) + .apiKey(chatModelProperties.getApiKey()).model(chatModelProperties.getModelName()) .temperature(chatModelProperties.getTemperature()) - .topP(chatModelProperties.getTopP()) - .maxRetries(chatModelProperties.getMaxRetries()) + .topP(chatModelProperties.getTopP()).maxRetries(chatModelProperties.getMaxRetries()) .maxToken(chatModelProperties.getMaxToken()) .logRequests(chatModelProperties.getLogRequests()) - .logResponses(chatModelProperties.getLogResponses()) - .build(); + .logResponses(chatModelProperties.getLogResponses()).build(); } @Bean @ConditionalOnProperty(PREFIX + ".streaming-chat-model.api-key") ZhipuAiStreamingChatModel zhipuStreamingChatModel(Properties properties) { ChatModelProperties chatModelProperties = properties.getStreamingChatModel(); - return ZhipuAiStreamingChatModel.builder() - .baseUrl(chatModelProperties.getBaseUrl()) - .apiKey(chatModelProperties.getApiKey()) - .model(chatModelProperties.getModelName()) + return ZhipuAiStreamingChatModel.builder().baseUrl(chatModelProperties.getBaseUrl()) + .apiKey(chatModelProperties.getApiKey()).model(chatModelProperties.getModelName()) .temperature(chatModelProperties.getTemperature()) - .topP(chatModelProperties.getTopP()) - .maxToken(chatModelProperties.getMaxToken()) + .topP(chatModelProperties.getTopP()).maxToken(chatModelProperties.getMaxToken()) .logRequests(chatModelProperties.getLogRequests()) - .logResponses(chatModelProperties.getLogResponses()) - .build(); + .logResponses(chatModelProperties.getLogResponses()).build(); } @Bean @ConditionalOnProperty(PREFIX + ".embedding-model.api-key") ZhipuAiEmbeddingModel zhipuEmbeddingModel(Properties properties) { EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel(); - return ZhipuAiEmbeddingModel.builder() - .baseUrl(embeddingModelProperties.getBaseUrl()) + return ZhipuAiEmbeddingModel.builder().baseUrl(embeddingModelProperties.getBaseUrl()) .apiKey(embeddingModelProperties.getApiKey()) .model(embeddingModelProperties.getModel()) .maxRetries(embeddingModelProperties.getMaxRetries()) .logRequests(embeddingModelProperties.getLogRequests()) - .logResponses(embeddingModelProperties.getLogResponses()) - .build(); + .logResponses(embeddingModelProperties.getLogResponses()).build(); } } diff --git a/common/src/test/java/com/tencent/supersonic/common/DateUtilsTest.java b/common/src/test/java/com/tencent/supersonic/common/DateUtilsTest.java index 90986056d..074490ca3 100644 --- a/common/src/test/java/com/tencent/supersonic/common/DateUtilsTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/DateUtilsTest.java @@ -47,14 +47,8 @@ class DateUtilsTest { String startDate = "2023-07-29"; String endDate = "2023-08-03"; List actualDateList = DateUtils.getDateList(startDate, endDate, DatePeriodEnum.DAY); - List expectedDateList = - Lists.newArrayList( - "2023-07-29", - "2023-07-30", - "2023-07-31", - "2023-08-01", - "2023-08-02", - "2023-08-03"); + List expectedDateList = Lists.newArrayList("2023-07-29", "2023-07-30", "2023-07-31", + "2023-08-01", "2023-08-02", "2023-08-03"); Assertions.assertEquals(expectedDateList, actualDateList); } diff --git a/common/src/test/java/com/tencent/supersonic/common/calcite/SqlParseUtilsTest.java b/common/src/test/java/com/tencent/supersonic/common/calcite/SqlParseUtilsTest.java index 106d8d6c9..9c4b7daa0 100644 --- a/common/src/test/java/com/tencent/supersonic/common/calcite/SqlParseUtilsTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/calcite/SqlParseUtilsTest.java @@ -21,9 +21,8 @@ class SqlParseUtilsTest { void addAliasToSql() throws SqlParseException { String addAliasToSql = - SqlParseUtils.addAliasToSql( - "select sum(pv) from ( select * from t_1 " - + "where sys_imp_date >= '2023-07-07' and sys_imp_date <= '2023-07-07' ) as t_sub_1"); + SqlParseUtils.addAliasToSql("select sum(pv) from ( select * from t_1 " + + "where sys_imp_date >= '2023-07-07' and sys_imp_date <= '2023-07-07' ) as t_sub_1"); Assert.assertTrue(addAliasToSql.toLowerCase().contains("as pv")); } @@ -31,102 +30,75 @@ class SqlParseUtilsTest { @Test void addFieldToSql() throws SqlParseException { - String addFieldToSql = - SqlParseUtils.addFieldsToSql( - "select pv from ( select * from t_1 " - + "where sys_imp_date >= '2023-07-07' and sys_imp_date <= '2023-07-07' ) as t_sub_1", - Collections.singletonList("uv")); + String addFieldToSql = SqlParseUtils.addFieldsToSql("select pv from ( select * from t_1 " + + "where sys_imp_date >= '2023-07-07' and sys_imp_date <= '2023-07-07' ) as t_sub_1", + Collections.singletonList("uv")); Assert.assertTrue(addFieldToSql.toLowerCase().contains("uv")); - addFieldToSql = - SqlParseUtils.addFieldsToSql( - "select uv from ( select * from t_1 " - + "where sys_imp_date >= '2023-07-07' and sys_imp_date <= '2023-07-07' ) as t_sub_1 " - + "order by play_count desc limit 10", - Collections.singletonList("pv")); + addFieldToSql = SqlParseUtils.addFieldsToSql("select uv from ( select * from t_1 " + + "where sys_imp_date >= '2023-07-07' and sys_imp_date <= '2023-07-07' ) as t_sub_1 " + + "order by play_count desc limit 10", Collections.singletonList("pv")); Assert.assertTrue(addFieldToSql.toLowerCase().contains("pv")); - addFieldToSql = - SqlParseUtils.addFieldsToSql( - "select uv from " - + "( select * from t_1 where sys_imp_date >= '2023-07-07' " - + " and sys_imp_date <= '2023-07-07' " - + ") as t_sub_1 " - + "where user_id = '张三' order by play_count desc limit 10", - Collections.singletonList("pv")); + addFieldToSql = SqlParseUtils.addFieldsToSql( + "select uv from " + "( select * from t_1 where sys_imp_date >= '2023-07-07' " + + " and sys_imp_date <= '2023-07-07' " + ") as t_sub_1 " + + "where user_id = '张三' order by play_count desc limit 10", + Collections.singletonList("pv")); Assert.assertTrue(addFieldToSql.toLowerCase().contains("pv")); } @Test void getSqlParseInfo() { - SqlParserInfo sqlParserInfo = - SqlParseUtils.getSqlParseInfo( - "select pv from " - + "( select * from t_1 where sys_imp_date >= '2023-07-07' and sys_imp_date <= '2023-07-07' )" - + " as t_sub_1 "); + SqlParserInfo sqlParserInfo = SqlParseUtils.getSqlParseInfo("select pv from " + + "( select * from t_1 where sys_imp_date >= '2023-07-07' and sys_imp_date <= '2023-07-07' )" + + " as t_sub_1 "); Assert.assertTrue(sqlParserInfo.getTableName().equalsIgnoreCase("t_1")); - List collect = - sqlParserInfo.getAllFields().stream() - .map(field -> field.toLowerCase()) - .collect(Collectors.toList()); + List collect = sqlParserInfo.getAllFields().stream() + .map(field -> field.toLowerCase()).collect(Collectors.toList()); Assert.assertTrue(collect.contains("pv")); Assert.assertTrue(!collect.contains("uv")); - List selectFields = - sqlParserInfo.getSelectFields().stream() - .map(field -> field.toLowerCase()) - .collect(Collectors.toList()); + List selectFields = sqlParserInfo.getSelectFields().stream() + .map(field -> field.toLowerCase()).collect(Collectors.toList()); Assert.assertTrue(selectFields.contains("pv")); Assert.assertTrue(!selectFields.contains("uv")); - sqlParserInfo = - SqlParseUtils.getSqlParseInfo( - "select uv from t_1 order by play_count desc limit 10"); + sqlParserInfo = SqlParseUtils + .getSqlParseInfo("select uv from t_1 order by play_count desc limit 10"); Assert.assertTrue(sqlParserInfo.getTableName().equalsIgnoreCase("t_1")); - collect = - sqlParserInfo.getAllFields().stream() - .map(field -> field.toLowerCase()) - .collect(Collectors.toList()); + collect = sqlParserInfo.getAllFields().stream().map(field -> field.toLowerCase()) + .collect(Collectors.toList()); Assert.assertTrue(collect.contains("uv")); Assert.assertTrue(collect.contains("play_count")); Assert.assertTrue(!collect.contains("pv")); - selectFields = - sqlParserInfo.getSelectFields().stream() - .map(field -> field.toLowerCase()) - .collect(Collectors.toList()); + selectFields = sqlParserInfo.getSelectFields().stream().map(field -> field.toLowerCase()) + .collect(Collectors.toList()); Assert.assertTrue(selectFields.contains("uv")); Assert.assertTrue(!selectFields.contains("pv")); Assert.assertTrue(!selectFields.contains("play_count")); - sqlParserInfo = - SqlParseUtils.getSqlParseInfo( - "select uv from " - + "( " - + " select * from t_1 where sys_imp_date >= '2023-07-07' and sys_imp_date <= '2023-07-07' " - + ") as t_sub_1 " - + "where user_id = '1' order by play_count desc limit 10"); + sqlParserInfo = SqlParseUtils.getSqlParseInfo("select uv from " + "( " + + " select * from t_1 where sys_imp_date >= '2023-07-07' and sys_imp_date <= '2023-07-07' " + + ") as t_sub_1 " + "where user_id = '1' order by play_count desc limit 10"); Assert.assertTrue(sqlParserInfo.getTableName().equalsIgnoreCase("t_1")); - collect = - sqlParserInfo.getAllFields().stream() - .map(field -> field.toLowerCase()) - .collect(Collectors.toList()); + collect = sqlParserInfo.getAllFields().stream().map(field -> field.toLowerCase()) + .collect(Collectors.toList()); Assert.assertTrue(collect.contains("uv")); Assert.assertTrue(collect.contains("play_count")); Assert.assertTrue(collect.contains("user_id")); Assert.assertTrue(!collect.contains("pv")); - selectFields = - sqlParserInfo.getSelectFields().stream() - .map(field -> field.toLowerCase()) - .collect(Collectors.toList()); + selectFields = sqlParserInfo.getSelectFields().stream().map(field -> field.toLowerCase()) + .collect(Collectors.toList()); Assert.assertTrue(selectFields.contains("uv")); Assert.assertTrue(!selectFields.contains("pv")); Assert.assertTrue(!selectFields.contains("user_id")); @@ -135,18 +107,12 @@ class SqlParseUtilsTest { @Test void getWhereFieldTest() { - SqlParserInfo sqlParserInfo = - SqlParseUtils.getSqlParseInfo( - "select uv from " - + " ( " - + " select * from t_1 where sys_imp_date >= '2023-07-07' and " - + "sys_imp_date <= '2023-07-07' and user_id = 22 " - + " ) as t_sub_1 " - + " where user_name_元 = 'zhangsan' order by play_count desc limit 10"); - List collect = - sqlParserInfo.getAllFields().stream() - .map(field -> field.toLowerCase()) - .collect(Collectors.toList()); + SqlParserInfo sqlParserInfo = SqlParseUtils.getSqlParseInfo("select uv from " + " ( " + + " select * from t_1 where sys_imp_date >= '2023-07-07' and " + + "sys_imp_date <= '2023-07-07' and user_id = 22 " + " ) as t_sub_1 " + + " where user_name_元 = 'zhangsan' order by play_count desc limit 10"); + List collect = sqlParserInfo.getAllFields().stream() + .map(field -> field.toLowerCase()).collect(Collectors.toList()); Assert.assertTrue(collect.contains("user_id")); } } diff --git a/common/src/test/java/com/tencent/supersonic/common/calcite/SqlWithMergerTest.java b/common/src/test/java/com/tencent/supersonic/common/calcite/SqlWithMergerTest.java index 81206c8fb..76a16c5c1 100644 --- a/common/src/test/java/com/tencent/supersonic/common/calcite/SqlWithMergerTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/calcite/SqlWithMergerTest.java @@ -13,14 +13,11 @@ class SqlWithMergerTest { @Test void testWithMerger() throws SqlParseException { - String sql1 = - "WITH DepartmentVisits AS (\n" - + " SELECT department, SUM(pv) AS 总访问次数\n" - + " FROM t_1\n" - + " WHERE sys_imp_date >= '2024-09-01' AND sys_imp_date <= '2024-09-29'\n" - + " GROUP BY department\n" - + ")\n" - + "SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100"; + String sql1 = "WITH DepartmentVisits AS (\n" + " SELECT department, SUM(pv) AS 总访问次数\n" + + " FROM t_1\n" + + " WHERE sys_imp_date >= '2024-09-01' AND sys_imp_date <= '2024-09-29'\n" + + " GROUP BY department\n" + ")\n" + + "SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100"; String sql2 = "SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv`\n" @@ -28,35 +25,22 @@ class SqlWithMergerTest { + "LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name`\n" + "FROM `s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`"; - String mergeSql = - SqlMergeWithUtils.mergeWith( - EngineType.MYSQL, - sql1, - Collections.singletonList(sql2), - Collections.singletonList("t_1")); + String mergeSql = SqlMergeWithUtils.mergeWith(EngineType.MYSQL, sql1, + Collections.singletonList(sql2), Collections.singletonList("t_1")); System.out.println(mergeSql); - sql1 = - "WITH DepartmentVisits AS (SELECT department, SUM(pv) AS 总访问次数 FROM t_1 WHERE sys_imp_date >= '2024-08-28' " - + "AND sys_imp_date <= '2024-09-28' GROUP BY department) SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100 LIMIT 1000"; + sql1 = "WITH DepartmentVisits AS (SELECT department, SUM(pv) AS 总访问次数 FROM t_1 WHERE sys_imp_date >= '2024-08-28' " + + "AND sys_imp_date <= '2024-09-28' GROUP BY department) SELECT COUNT(*) FROM DepartmentVisits WHERE 总访问次数 > 100 LIMIT 1000"; - sql2 = - "SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv`\n" - + "FROM\n" - + "(SELECT `user_name`, `department`\n" - + "FROM\n" - + "`s2_user_department`) AS `t2`\n" - + "LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name`\n" - + "FROM\n" - + "`s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`"; + sql2 = "SELECT `t3`.`sys_imp_date`, `t2`.`department`, `t3`.`s2_pv_uv_statis_pv` AS `pv`\n" + + "FROM\n" + "(SELECT `user_name`, `department`\n" + "FROM\n" + + "`s2_user_department`) AS `t2`\n" + + "LEFT JOIN (SELECT 1 AS `s2_pv_uv_statis_pv`, `imp_date` AS `sys_imp_date`, `user_name`\n" + + "FROM\n" + "`s2_pv_uv_statis`) AS `t3` ON `t2`.`user_name` = `t3`.`user_name`"; - mergeSql = - SqlMergeWithUtils.mergeWith( - EngineType.H2, - sql1, - Collections.singletonList(sql2), - Collections.singletonList("t_1")); + mergeSql = SqlMergeWithUtils.mergeWith(EngineType.H2, sql1, Collections.singletonList(sql2), + Collections.singletonList("t_1")); System.out.println(mergeSql); } diff --git a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlAddHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlAddHelperTest.java index 962700e16..02a92d9c6 100644 --- a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlAddHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlAddHelperTest.java @@ -19,9 +19,8 @@ class SqlAddHelperTest { @Test void testAddWhere() throws JSQLParserException { - String sql = - "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " - + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; + String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " + + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; sql = SqlAddHelper.addWhere(sql, "column_a", 123444555); List selectFields = SqlSelectHelper.getAllSelectFields(sql); @@ -35,53 +34,42 @@ class SqlAddHelperTest { Expression expression = CCJSqlParserUtil.parseCondExpression(" ( column_c = 111 or column_d = 1111)"); - sql = - SqlAddHelper.addWhere( - "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " - + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1", - expression); + sql = SqlAddHelper.addWhere("select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " + + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1", expression); Assert.assertEquals(sql.contains("column_c = 111"), true); - sql = - "select 部门,sum (访问次数) from 超音数 where 用户 = alice or 发布日期 ='2023-07-03' group by 部门 limit 1"; + sql = "select 部门,sum (访问次数) from 超音数 where 用户 = alice or 发布日期 ='2023-07-03' group by 部门 limit 1"; sql = SqlAddHelper.addParenthesisToWhere(sql); sql = SqlAddHelper.addWhere(sql, "数据日期", "2023-08-08"); - Assert.assertEquals( - sql, - "SELECT 部门, sum(访问次数) FROM 超音数 WHERE " - + "(用户 = alice OR 发布日期 = '2023-07-03') AND 数据日期 = '2023-08-08' GROUP BY 部门 LIMIT 1"); + Assert.assertEquals(sql, "SELECT 部门, sum(访问次数) FROM 超音数 WHERE " + + "(用户 = alice OR 发布日期 = '2023-07-03') AND 数据日期 = '2023-08-08' GROUP BY 部门 LIMIT 1"); } @Test void testAddFunctionToSelect() { - String sql = - "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " - + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; + String sql = "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; List havingExpressionList = SqlSelectHelper.getHavingExpression(sql); String replaceSql = SqlAddHelper.addFunctionToSelect(sql, havingExpressionList); System.out.println(replaceSql); - Assert.assertEquals( - "SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " - + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", + Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " + + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", replaceSql); - sql = - "SELECT user_name,sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " - + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; + sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; havingExpressionList = SqlSelectHelper.getHavingExpression(sql); replaceSql = SqlAddHelper.addFunctionToSelect(sql, havingExpressionList); System.out.println(replaceSql); - Assert.assertEquals( - "SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " - + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", + Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " + + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", replaceSql); - sql = - "SELECT user_name,sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') AND " - + "sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; + sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') AND " + + "sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; havingExpressionList = SqlSelectHelper.getHavingExpression(sql); replaceSql = SqlAddHelper.addFunctionToSelect(sql, havingExpressionList); @@ -94,33 +82,28 @@ class SqlAddHelperTest { @Test void testAddAggregateToField() { - String sql = - "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " - + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; + String sql = "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; List havingExpressionList = SqlSelectHelper.getHavingExpression(sql); String replaceSql = SqlAddHelper.addFunctionToSelect(sql, havingExpressionList); System.out.println(replaceSql); - Assert.assertEquals( - "SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " - + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", + Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " + + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", replaceSql); - sql = - "SELECT user_name,sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " - + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; + sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; havingExpressionList = SqlSelectHelper.getHavingExpression(sql); replaceSql = SqlAddHelper.addFunctionToSelect(sql, havingExpressionList); System.out.println(replaceSql); - Assert.assertEquals( - "SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " - + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", + Assert.assertEquals("SELECT user_name, sum(pv) FROM 超音数 WHERE sys_imp_date <= '2023-09-03' " + + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000", replaceSql); - sql = - "SELECT user_name,sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') AND " - + "sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; + sql = "SELECT user_name,sum(pv) FROM 超音数 WHERE (sys_imp_date <= '2023-09-03') AND " + + "sys_imp_date = '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; havingExpressionList = SqlSelectHelper.getHavingExpression(sql); replaceSql = SqlAddHelper.addFunctionToSelect(sql, havingExpressionList); @@ -145,14 +128,11 @@ class SqlAddHelperTest { String replaceSql = SqlAddHelper.addAggregateToField(sql, filedNameToAggregate); replaceSql = SqlAddHelper.addGroupBy(replaceSql, groupByFields); - Assert.assertEquals( - "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " - + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", - replaceSql); + Assert.assertEquals("SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", replaceSql); - sql = - "select department, pv from t_1 where sys_imp_date = '2023-09-11' and pv >1 " - + "order by pv desc limit 10"; + sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' and pv >1 " + + "order by pv desc limit 10"; replaceSql = SqlAddHelper.addAggregateToField(sql, filedNameToAggregate); replaceSql = SqlAddHelper.addGroupBy(replaceSql, groupByFields); @@ -165,23 +145,18 @@ class SqlAddHelperTest { replaceSql = SqlAddHelper.addAggregateToField(sql, filedNameToAggregate); replaceSql = SqlAddHelper.addGroupBy(replaceSql, groupByFields); - Assert.assertEquals( - "SELECT department, sum(pv) FROM t_1 WHERE sum(pv) > 1 " - + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", - replaceSql); + Assert.assertEquals("SELECT department, sum(pv) FROM t_1 WHERE sum(pv) > 1 " + + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", replaceSql); sql = "select department, pv from t_1 where sum(pv) >1 order by pv desc limit 10"; replaceSql = SqlAddHelper.addAggregateToField(sql, filedNameToAggregate); replaceSql = SqlAddHelper.addGroupBy(replaceSql, groupByFields); - Assert.assertEquals( - "SELECT department, sum(pv) FROM t_1 WHERE sum(pv) > 1 " - + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", - replaceSql); + Assert.assertEquals("SELECT department, sum(pv) FROM t_1 WHERE sum(pv) > 1 " + + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", replaceSql); - sql = - "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' and sum(pv) >1 " - + "GROUP BY department order by pv desc limit 10"; + sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' and sum(pv) >1 " + + "GROUP BY department order by pv desc limit 10"; replaceSql = SqlAddHelper.addAggregateToField(sql, filedNameToAggregate); replaceSql = SqlAddHelper.addGroupBy(replaceSql, groupByFields); @@ -190,9 +165,8 @@ class SqlAddHelperTest { + "AND sum(pv) > 1 GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", replaceSql); - sql = - "select department, pv from t_1 where sys_imp_date = '2023-09-11' and pv >1 " - + "GROUP BY department order by pv desc limit 10"; + sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' and pv >1 " + + "GROUP BY department order by pv desc limit 10"; replaceSql = SqlAddHelper.addAggregateToField(sql, filedNameToAggregate); replaceSql = SqlAddHelper.addGroupBy(replaceSql, groupByFields); @@ -201,9 +175,8 @@ class SqlAddHelperTest { + "AND sum(pv) > 1 GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", replaceSql); - sql = - "select department, pv from t_1 where sys_imp_date = '2023-09-11' and pv >1 and department = 'HR' " - + "GROUP BY department order by pv desc limit 10"; + sql = "select department, pv from t_1 where sys_imp_date = '2023-09-11' and pv >1 and department = 'HR' " + + "GROUP BY department order by pv desc limit 10"; replaceSql = SqlAddHelper.addAggregateToField(sql, filedNameToAggregate); replaceSql = SqlAddHelper.addGroupBy(replaceSql, groupByFields); @@ -212,9 +185,8 @@ class SqlAddHelperTest { + "AND department = 'HR' GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", replaceSql); - sql = - "select department, pv from t_1 where (pv >1 and department = 'HR') " - + " and sys_imp_date = '2023-09-11' GROUP BY department order by pv desc limit 10"; + sql = "select department, pv from t_1 where (pv >1 and department = 'HR') " + + " and sys_imp_date = '2023-09-11' GROUP BY department order by pv desc limit 10"; replaceSql = SqlAddHelper.addAggregateToField(sql, filedNameToAggregate); replaceSql = SqlAddHelper.addGroupBy(replaceSql, groupByFields); @@ -223,18 +195,15 @@ class SqlAddHelperTest { + "sys_imp_date = '2023-09-11' GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", replaceSql); - sql = - "select department, sum(pv) as pv from t_1 where sys_imp_date = '2023-09-11' GROUP BY " - + "department order by pv desc limit 10"; + sql = "select department, sum(pv) as pv from t_1 where sys_imp_date = '2023-09-11' GROUP BY " + + "department order by pv desc limit 10"; replaceSql = SqlReplaceHelper.replaceAlias(sql); replaceSql = SqlAddHelper.addAggregateToField(replaceSql, filedNameToAggregate); replaceSql = SqlAddHelper.addGroupBy(replaceSql, groupByFields); - Assert.assertEquals( - "SELECT department, sum(pv) AS pv " - + "FROM t_1 WHERE sys_imp_date = '2023-09-11' GROUP BY department " - + "ORDER BY sum(pv) DESC LIMIT 10", - replaceSql); + Assert.assertEquals("SELECT department, sum(pv) AS pv " + + "FROM t_1 WHERE sys_imp_date = '2023-09-11' GROUP BY department " + + "ORDER BY sum(pv) DESC LIMIT 10", replaceSql); } @Test @@ -256,9 +225,8 @@ class SqlAddHelperTest { + "GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10", replaceSql); - sql = - "select department, uv from t_1 where sys_imp_date = '2023-09-11' and uv >1 " - + "order by uv desc limit 10"; + sql = "select department, uv from t_1 where sys_imp_date = '2023-09-11' and uv >1 " + + "order by uv desc limit 10"; replaceSql = SqlAddHelper.addAggregateToField(sql, filedNameToAggregate); replaceSql = SqlAddHelper.addGroupBy(replaceSql, groupByFields); @@ -276,8 +244,7 @@ class SqlAddHelperTest { + "GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10", replaceSql); - sql = - "select department, uv from t_1 where count(DISTINCT uv) >1 order by uv desc limit 10"; + sql = "select department, uv from t_1 where count(DISTINCT uv) >1 order by uv desc limit 10"; replaceSql = SqlAddHelper.addAggregateToField(sql, filedNameToAggregate); replaceSql = SqlAddHelper.addGroupBy(replaceSql, groupByFields); @@ -286,10 +253,9 @@ class SqlAddHelperTest { + "GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10", replaceSql); - sql = - "select department, count(DISTINCT uv) from t_1 where sys_imp_date = '2023-09-11'" - + " and count(DISTINCT uv) >1 " - + "GROUP BY department order by count(DISTINCT uv) desc limit 10"; + sql = "select department, count(DISTINCT uv) from t_1 where sys_imp_date = '2023-09-11'" + + " and count(DISTINCT uv) >1 " + + "GROUP BY department order by count(DISTINCT uv) desc limit 10"; replaceSql = SqlAddHelper.addAggregateToField(sql, filedNameToAggregate); replaceSql = SqlAddHelper.addGroupBy(replaceSql, groupByFields); @@ -298,9 +264,8 @@ class SqlAddHelperTest { + "AND count(DISTINCT uv) > 1 GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10", replaceSql); - sql = - "select department, uv from t_1 where sys_imp_date = '2023-09-11' and uv >1 " - + "GROUP BY department order by count(DISTINCT uv) desc limit 10"; + sql = "select department, uv from t_1 where sys_imp_date = '2023-09-11' and uv >1 " + + "GROUP BY department order by count(DISTINCT uv) desc limit 10"; replaceSql = SqlAddHelper.addAggregateToField(sql, filedNameToAggregate); replaceSql = SqlAddHelper.addGroupBy(replaceSql, groupByFields); @@ -309,21 +274,18 @@ class SqlAddHelperTest { + "AND count(DISTINCT uv) > 1 GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10", replaceSql); - sql = - "select department, uv from t_1 where sys_imp_date = '2023-09-11' and uv >1 and department = 'HR' " - + "GROUP BY department order by uv desc limit 10"; + sql = "select department, uv from t_1 where sys_imp_date = '2023-09-11' and uv >1 and department = 'HR' " + + "GROUP BY department order by uv desc limit 10"; replaceSql = SqlAddHelper.addAggregateToField(sql, filedNameToAggregate); replaceSql = SqlAddHelper.addGroupBy(replaceSql, groupByFields); - Assert.assertEquals( - "SELECT department, count(DISTINCT uv) FROM t_1 WHERE sys_imp_date = " - + "'2023-09-11' AND count(DISTINCT uv) > 1 " - + "AND department = 'HR' GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10", + Assert.assertEquals("SELECT department, count(DISTINCT uv) FROM t_1 WHERE sys_imp_date = " + + "'2023-09-11' AND count(DISTINCT uv) > 1 " + + "AND department = 'HR' GROUP BY department ORDER BY count(DISTINCT uv) DESC LIMIT 10", replaceSql); - sql = - "select department, uv from t_1 where (uv >1 and department = 'HR') " - + " and sys_imp_date = '2023-09-11' GROUP BY department order by uv desc limit 10"; + sql = "select department, uv from t_1 where (uv >1 and department = 'HR') " + + " and sys_imp_date = '2023-09-11' GROUP BY department order by uv desc limit 10"; replaceSql = SqlAddHelper.addAggregateToField(sql, filedNameToAggregate); replaceSql = SqlAddHelper.addGroupBy(replaceSql, groupByFields); @@ -334,39 +296,32 @@ class SqlAddHelperTest { + "count(DISTINCT uv) DESC LIMIT 10", replaceSql); - sql = - "select department, count(DISTINCT uv) as uv from t_1 where sys_imp_date = '2023-09-11' GROUP BY " - + "department order by uv desc limit 10"; + sql = "select department, count(DISTINCT uv) as uv from t_1 where sys_imp_date = '2023-09-11' GROUP BY " + + "department order by uv desc limit 10"; replaceSql = SqlReplaceHelper.replaceAlias(sql); replaceSql = SqlAddHelper.addAggregateToField(replaceSql, filedNameToAggregate); replaceSql = SqlAddHelper.addGroupBy(replaceSql, groupByFields); - Assert.assertEquals( - "SELECT department, count(DISTINCT uv) AS uv " - + "FROM t_1 WHERE sys_imp_date = '2023-09-11' GROUP BY department " - + "ORDER BY count(DISTINCT uv) DESC LIMIT 10", - replaceSql); + Assert.assertEquals("SELECT department, count(DISTINCT uv) AS uv " + + "FROM t_1 WHERE sys_imp_date = '2023-09-11' GROUP BY department " + + "ORDER BY count(DISTINCT uv) DESC LIMIT 10", replaceSql); } @Test void testAddGroupBy() { - String sql = - "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' " - + "order by sum(pv) desc limit 10"; + String sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' " + + "order by sum(pv) desc limit 10"; Set groupByFields = new HashSet<>(); groupByFields.add("department"); String replaceSql = SqlAddHelper.addGroupBy(sql, groupByFields); - Assert.assertEquals( - "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " - + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", - replaceSql); + Assert.assertEquals("SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + + "GROUP BY department ORDER BY sum(pv) DESC LIMIT 10", replaceSql); - sql = - "select department, sum(pv) from t_1 where (department = 'HR') and sys_imp_date = '2023-09-11' " - + "order by sum(pv) desc limit 10"; + sql = "select department, sum(pv) from t_1 where (department = 'HR') and sys_imp_date = '2023-09-11' " + + "order by sum(pv) desc limit 10"; replaceSql = SqlAddHelper.addGroupBy(sql, groupByFields); @@ -378,9 +333,8 @@ class SqlAddHelperTest { @Test void testAddHaving() { - String sql = - "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' and " - + "sum(pv) > 2000 group by department order by sum(pv) desc limit 10"; + String sql = "select department, sum(pv) from t_1 where sys_imp_date = '2023-09-11' and " + + "sum(pv) > 2000 group by department order by sum(pv) desc limit 10"; List groupByFields = new ArrayList<>(); groupByFields.add("department"); @@ -389,47 +343,40 @@ class SqlAddHelperTest { String replaceSql = SqlAddHelper.addHaving(sql, fieldNames); - Assert.assertEquals( - "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " - + "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10", + Assert.assertEquals("SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + + "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10", replaceSql); - sql = - "select department, sum(pv) from t_1 where (sum(pv) > 2000) and sys_imp_date = '2023-09-11' " - + "group by department order by sum(pv) desc limit 10"; + sql = "select department, sum(pv) from t_1 where (sum(pv) > 2000) and sys_imp_date = '2023-09-11' " + + "group by department order by sum(pv) desc limit 10"; replaceSql = SqlAddHelper.addHaving(sql, fieldNames); - Assert.assertEquals( - "SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " - + "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10", + Assert.assertEquals("SELECT department, sum(pv) FROM t_1 WHERE sys_imp_date = '2023-09-11' " + + "GROUP BY department HAVING sum(pv) > 2000 ORDER BY sum(pv) DESC LIMIT 10", replaceSql); } @Test void testAddParenthesisToWhere() { - String sql = - "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" - + " order by 播放量 desc limit 11"; + String sql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; String replaceSql = SqlAddHelper.addParenthesisToWhere(sql); - Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 " - + "AND 歌曲名 = '邓紫棋' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01') " - + "ORDER BY 播放量 DESC LIMIT 11", - replaceSql); + Assert.assertEquals("SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1 " + + "AND 歌曲名 = '邓紫棋' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01') " + + "ORDER BY 播放量 DESC LIMIT 11", replaceSql); } @Test void testAddFieldsToSelect() { String correctS2SQL = "SELECT 用户, 页面 FROM 超音数用户部门 GROUP BY 用户, 页面 ORDER BY count(*) DESC"; - String replaceFields = - SqlAddHelper.addFieldsToSelect( - correctS2SQL, SqlSelectHelper.getOrderByFields(correctS2SQL)); + String replaceFields = SqlAddHelper.addFieldsToSelect(correctS2SQL, + SqlSelectHelper.getOrderByFields(correctS2SQL)); - Assert.assertEquals( - "SELECT 用户, 页面 FROM 超音数用户部门 GROUP BY 用户, 页面 ORDER BY count(*) DESC", replaceFields); + Assert.assertEquals("SELECT 用户, 页面 FROM 超音数用户部门 GROUP BY 用户, 页面 ORDER BY count(*) DESC", + replaceFields); } } diff --git a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlAsHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlAsHelperTest.java index b15e33dd7..1bbb691bf 100644 --- a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlAsHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlAsHelperTest.java @@ -9,31 +9,14 @@ class SqlAsHelperTest { @Test void getAsFields() { - String sql = - "WITH SalesData AS (\n" - + " SELECT \n" - + " SalesID,\n" - + " ProductID,\n" - + " Quantity,\n" - + " Price,\n" - + " (Quantity * Price) AS TotalSales\n" - + " FROM \n" - + " Sales\n" - + ")\n" - + "SELECT \n" - + " ProductID,\n" - + " SUM(TotalSales) AS TotalRevenue,\n" - + " COUNT(SalesID) AS NumberOfSales\n" - + "FROM \n" - + " SalesData\n" - + "WHERE \n" - + " Quantity > 10\n" - + "GROUP BY \n" - + " ProductID\n" - + "HAVING \n" - + " SUM(TotalSales) > 1000\n" - + "ORDER BY \n" - + " TotalRevenue DESC"; + String sql = "WITH SalesData AS (\n" + " SELECT \n" + " SalesID,\n" + + " ProductID,\n" + " Quantity,\n" + " Price,\n" + + " (Quantity * Price) AS TotalSales\n" + " FROM \n" + " Sales\n" + + ")\n" + "SELECT \n" + " ProductID,\n" + + " SUM(TotalSales) AS TotalRevenue,\n" + " COUNT(SalesID) AS NumberOfSales\n" + + "FROM \n" + " SalesData\n" + "WHERE \n" + " Quantity > 10\n" + "GROUP BY \n" + + " ProductID\n" + "HAVING \n" + " SUM(TotalSales) > 1000\n" + "ORDER BY \n" + + " TotalRevenue DESC"; List asFields = SqlAsHelper.getAsFields(sql); Assert.assertTrue(asFields.contains("NumberOfSales")); Assert.assertTrue(asFields.contains("TotalRevenue")); diff --git a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlDateSelectHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlDateSelectHelperTest.java index c5ae88741..046609b07 100644 --- a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlDateSelectHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlDateSelectHelperTest.java @@ -9,38 +9,32 @@ class SqlDateSelectHelperTest { @Test void testGetDateBoundInfo() { - String sql = - "SELECT 维度1,sum(播放量) FROM 数据库 " - + "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1"; + String sql = "SELECT 维度1,sum(播放量) FROM 数据库 " + + "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1"; DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql); Assert.assertEquals(dateBoundInfo.getLowerBound(), ">="); Assert.assertEquals(dateBoundInfo.getLowerDate(), "2023-11-17"); - sql = - "SELECT 维度1,sum(播放量) FROM 数据库 " - + "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1"; + sql = "SELECT 维度1,sum(播放量) FROM 数据库 " + + "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1"; dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql); Assert.assertEquals(dateBoundInfo.getLowerBound(), ">"); Assert.assertEquals(dateBoundInfo.getLowerDate(), "2023-11-17"); - sql = - "SELECT 维度1,sum(播放量) FROM 数据库 " - + "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1"; + sql = "SELECT 维度1,sum(播放量) FROM 数据库 " + + "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1"; dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql); Assert.assertEquals(dateBoundInfo.getUpperBound(), "<="); Assert.assertEquals(dateBoundInfo.getUpperDate(), "2023-11-17"); - sql = - "SELECT 维度1,sum(播放量) FROM 数据库 " - + "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1"; + sql = "SELECT 维度1,sum(播放量) FROM 数据库 " + + "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1"; dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql); Assert.assertEquals(dateBoundInfo.getUpperBound(), "<"); Assert.assertEquals(dateBoundInfo.getUpperDate(), "2023-11-17"); - sql = - "SELECT 维度1,sum(播放量) FROM 数据库 " - + "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-10-17' " - + "AND 数据日期 <= '2023-11-17' GROUP BY 维度1"; + sql = "SELECT 维度1,sum(播放量) FROM 数据库 " + "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-10-17' " + + "AND 数据日期 <= '2023-11-17' GROUP BY 维度1"; dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(sql); Assert.assertEquals(dateBoundInfo.getUpperBound(), "<="); Assert.assertEquals(dateBoundInfo.getUpperDate(), "2023-11-17"); diff --git a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlRemoveHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlRemoveHelperTest.java index 54e6a8112..4cf5f0d13 100644 --- a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlRemoveHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlRemoveHelperTest.java @@ -26,63 +26,49 @@ class SqlRemoveHelperTest { @Test void testRemoveSameFieldFromSelect() { - String sql = - "select 歌曲名,歌手名,粉丝数,粉丝数,sum(粉丝数),sum(粉丝数),avg(播放量),avg(播放量)" - + " from 歌曲库 where sum(粉丝数) > 20000 and 2>1 and " - + "sum(播放量) > 20000 and 1=1 HAVING sum(播放量) > 20000 and 3>1"; + String sql = "select 歌曲名,歌手名,粉丝数,粉丝数,sum(粉丝数),sum(粉丝数),avg(播放量),avg(播放量)" + + " from 歌曲库 where sum(粉丝数) > 20000 and 2>1 and " + + "sum(播放量) > 20000 and 1=1 HAVING sum(播放量) > 20000 and 3>1"; sql = SqlRemoveHelper.removeSameFieldFromSelect(sql); System.out.println(sql); - sql = - "SELECT 结算播放量 FROM 艺人 WHERE (歌手名 IN ('林俊杰', '陈奕迅')) AND (数据日期 >= '2024-04-04' AND 数据日期 <= '2024-04-04')"; + sql = "SELECT 结算播放量 FROM 艺人 WHERE (歌手名 IN ('林俊杰', '陈奕迅')) AND (数据日期 >= '2024-04-04' AND 数据日期 <= '2024-04-04')"; List fieldExpressionList = SqlSelectHelper.getWhereExpressions(sql); - fieldExpressionList.stream() - .forEach( - fieldExpression -> { - System.out.println(fieldExpression.toString()); - }); + fieldExpressionList.stream().forEach(fieldExpression -> { + System.out.println(fieldExpression.toString()); + }); } @Test void testRemoveWhereHavingCondition() { - String sql = - "select 歌曲名 from 歌曲库 where sum(粉丝数) > 20000 and 2>1 and " - + "sum(播放量) > 20000 and 1=1 HAVING sum(播放量) > 20000 and 3>1"; + String sql = "select 歌曲名 from 歌曲库 where sum(粉丝数) > 20000 and 2>1 and " + + "sum(播放量) > 20000 and 1=1 HAVING sum(播放量) > 20000 and 3>1"; sql = SqlRemoveHelper.removeNumberFilter(sql); System.out.println(sql); Assert.assertEquals( "SELECT 歌曲名 FROM 歌曲库 WHERE sum(粉丝数) > 20000 AND sum(播放量) > 20000 HAVING sum(播放量) > 20000", sql); - sql = - "SELECT 歌曲,sum(播放量) FROM 歌曲库\n" - + "WHERE (歌手名 = '张三' AND 2 > 1) AND 数据日期 = '2023-11-07'\n" - + "GROUP BY 歌曲名 HAVING sum(播放量) > 100000"; + sql = "SELECT 歌曲,sum(播放量) FROM 歌曲库\n" + + "WHERE (歌手名 = '张三' AND 2 > 1) AND 数据日期 = '2023-11-07'\n" + + "GROUP BY 歌曲名 HAVING sum(播放量) > 100000"; sql = SqlRemoveHelper.removeNumberFilter(sql); System.out.println(sql); - Assert.assertEquals( - "SELECT 歌曲, sum(播放量) FROM 歌曲库 WHERE (歌手名 = '张三') " - + "AND 数据日期 = '2023-11-07' GROUP BY 歌曲名 HAVING sum(播放量) > 100000", - sql); - sql = - "SELECT 歌曲名,sum(播放量) FROM 歌曲库 WHERE (1 = 1 AND 1 = 1 AND 2 > 1 )" - + "AND 1 = 1 AND 歌曲类型 IN ('类型一', '类型二') AND 歌手名 IN ('林俊杰', '周杰伦')" - + "AND 数据日期 = '2023-11-07' GROUP BY 歌曲名 HAVING 2 > 1 AND SUM(播放量) >= 1000"; + Assert.assertEquals("SELECT 歌曲, sum(播放量) FROM 歌曲库 WHERE (歌手名 = '张三') " + + "AND 数据日期 = '2023-11-07' GROUP BY 歌曲名 HAVING sum(播放量) > 100000", sql); + sql = "SELECT 歌曲名,sum(播放量) FROM 歌曲库 WHERE (1 = 1 AND 1 = 1 AND 2 > 1 )" + + "AND 1 = 1 AND 歌曲类型 IN ('类型一', '类型二') AND 歌手名 IN ('林俊杰', '周杰伦')" + + "AND 数据日期 = '2023-11-07' GROUP BY 歌曲名 HAVING 2 > 1 AND SUM(播放量) >= 1000"; sql = SqlRemoveHelper.removeNumberFilter(sql); System.out.println(sql); - Assert.assertEquals( - "SELECT 歌曲名, sum(播放量) FROM 歌曲库 WHERE 歌曲类型 IN ('类型一', '类型二') " - + "AND 歌手名 IN ('林俊杰', '周杰伦') AND 数据日期 = '2023-11-07' " - + "GROUP BY 歌曲名 HAVING SUM(播放量) >= 1000", - sql); + Assert.assertEquals("SELECT 歌曲名, sum(播放量) FROM 歌曲库 WHERE 歌曲类型 IN ('类型一', '类型二') " + + "AND 歌手名 IN ('林俊杰', '周杰伦') AND 数据日期 = '2023-11-07' " + + "GROUP BY 歌曲名 HAVING SUM(播放量) >= 1000", sql); - sql = - "SELECT 品牌名称,法人 FROM 互联网企业 WHERE (2 > 1 AND 1 = 1) AND 数据日期 = '2023-10-31'" - + "GROUP BY 品牌名称, 法人 HAVING 2 > 1 AND sum(注册资本) > 100000000 AND sum(营收占比) = 0.5 and 1 = 1"; + sql = "SELECT 品牌名称,法人 FROM 互联网企业 WHERE (2 > 1 AND 1 = 1) AND 数据日期 = '2023-10-31'" + + "GROUP BY 品牌名称, 法人 HAVING 2 > 1 AND sum(注册资本) > 100000000 AND sum(营收占比) = 0.5 and 1 = 1"; sql = SqlRemoveHelper.removeNumberFilter(sql); System.out.println(sql); - Assert.assertEquals( - "SELECT 品牌名称, 法人 FROM 互联网企业 WHERE 数据日期 = '2023-10-31' GROUP BY " - + "品牌名称, 法人 HAVING sum(注册资本) > 100000000 AND sum(营收占比) = 0.5", - sql); + Assert.assertEquals("SELECT 品牌名称, 法人 FROM 互联网企业 WHERE 数据日期 = '2023-10-31' GROUP BY " + + "品牌名称, 法人 HAVING sum(注册资本) > 100000000 AND sum(营收占比) = 0.5", sql); } @Test @@ -96,42 +82,33 @@ class SqlRemoveHelperTest { @Test void testRemoveWhereCondition() { - String sql = - "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" - + " order by 播放量 desc limit 11"; + String sql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; Set removeFieldNames = new HashSet<>(); removeFieldNames.add("歌曲名"); String replaceSql = SqlRemoveHelper.removeWhereCondition(sql, removeFieldNames); - Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 " - + "AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' " - + "ORDER BY 播放量 DESC LIMIT 11", - replaceSql); + Assert.assertEquals("SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 " + + "AND 数据日期 = '2023-08-09' AND 歌曲发布时 = '2023-08-01' " + + "ORDER BY 播放量 DESC LIMIT 11", replaceSql); - sql = - "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌曲名 in ('邓紫棋','周杰伦') and 歌曲名 in ('邓紫棋') and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" - + " order by 播放量 desc limit 11"; + sql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌曲名 in ('邓紫棋','周杰伦') and 歌曲名 in ('邓紫棋') and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; replaceSql = SqlRemoveHelper.removeWhereCondition(sql, removeFieldNames); - Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 " - + "AND 数据日期 = '2023-08-09' AND " - + "歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11", - replaceSql); + Assert.assertEquals("SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 " + + "AND 数据日期 = '2023-08-09' AND " + + "歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql); - sql = - "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌曲名 in ('邓紫棋','周杰伦') and 歌曲名 in ('邓紫棋')) and 数据日期 = '2023-08-09' " - + " order by 播放量 desc limit 11"; + sql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌曲名 in ('邓紫棋','周杰伦') and 歌曲名 in ('邓紫棋')) and 数据日期 = '2023-08-09' " + + " order by 播放量 desc limit 11"; replaceSql = SqlRemoveHelper.removeWhereCondition(sql, removeFieldNames); - Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1) " - + "AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11", - replaceSql); + Assert.assertEquals("SELECT 歌曲名 FROM 歌曲库 WHERE (datediff('day', 发布日期, '2023-08-09') <= 1) " + + "AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11", replaceSql); } @Test @@ -147,8 +124,7 @@ class SqlRemoveHelperTest { "SELECT 歌曲名 FROM 歌曲库 WHERE 歌曲名 = '邓紫棋' AND 数据日期 = '2023-08-09' AND 歌曲发布时间 = '2023-08-01'", replaceSql); - sql = - "select 数据日期 from 歌曲库 where 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时间 = '2023-08-01'"; + sql = "select 数据日期 from 歌曲库 where 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时间 = '2023-08-01'"; replaceSql = SqlRemoveHelper.removeSelect(sql, removeFieldNames); @@ -159,9 +135,8 @@ class SqlRemoveHelperTest { @Test void testRemoveGroupBy() { - String sql = - "select 数据日期 from 歌曲库 where 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and " - + "歌曲发布时间 = '2023-08-01' group by 数据日期"; + String sql = "select 数据日期 from 歌曲库 where 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and " + + "歌曲发布时间 = '2023-08-01' group by 数据日期"; Set removeFieldNames = new HashSet<>(); removeFieldNames.add("数据日期"); @@ -174,9 +149,8 @@ class SqlRemoveHelperTest { @Test void testRemoveIsNullInWhere() { - String sql = - "select 数据日期 from 歌曲库 where 歌曲名 is null and 数据日期 = '2023-08-09' and " - + "歌曲发布时间 = '2023-08-01' group by 数据日期"; + String sql = "select 数据日期 from 歌曲库 where 歌曲名 is null and 数据日期 = '2023-08-09' and " + + "歌曲发布时间 = '2023-08-01' group by 数据日期"; Set removeFieldNames = new HashSet<>(); removeFieldNames.add("歌曲名"); @@ -186,9 +160,8 @@ class SqlRemoveHelperTest { "SELECT 数据日期 FROM 歌曲库 WHERE 数据日期 = '2023-08-09' AND 歌曲发布时间 = '2023-08-01' GROUP BY 数据日期", replaceSql); - sql = - "select 数据日期 from 歌曲库 where 歌曲名 is null and 数据日期 = '2023-08-09' and " - + "歌曲发布时间 = '2023-08-01' group by 数据日期 having 歌曲名 is null"; + sql = "select 数据日期 from 歌曲库 where 歌曲名 is null and 数据日期 = '2023-08-09' and " + + "歌曲发布时间 = '2023-08-01' group by 数据日期 having 歌曲名 is null"; replaceSql = SqlRemoveHelper.removeIsNullInWhere(sql, removeFieldNames); @@ -199,9 +172,8 @@ class SqlRemoveHelperTest { @Test void testRemoveIsNotNullInWhere() { - String sql = - "select 数据日期 from 歌曲库 where 歌曲名 is not null and 数据日期 = '2023-08-09' and " - + "歌曲发布时间 = '2023-08-01' group by 数据日期"; + String sql = "select 数据日期 from 歌曲库 where 歌曲名 is not null and 数据日期 = '2023-08-09' and " + + "歌曲发布时间 = '2023-08-01' group by 数据日期"; Set removeFieldNames = new HashSet<>(); removeFieldNames.add("歌曲名"); @@ -211,9 +183,8 @@ class SqlRemoveHelperTest { "SELECT 数据日期 FROM 歌曲库 WHERE 数据日期 = '2023-08-09' AND 歌曲发布时间 = '2023-08-01' GROUP BY 数据日期", replaceSql); - sql = - "select 数据日期 from 歌曲库 where 歌曲名 is not null and 数据日期 = '2023-08-09' and " - + "歌曲发布时间 = '2023-08-01' group by 数据日期 having 歌曲名 is not null"; + sql = "select 数据日期 from 歌曲库 where 歌曲名 is not null and 数据日期 = '2023-08-09' and " + + "歌曲发布时间 = '2023-08-01' group by 数据日期 having 歌曲名 is not null"; replaceSql = SqlRemoveHelper.removeNotNullInWhere(sql, removeFieldNames); diff --git a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java index 3af5fdb79..19af21a0c 100644 --- a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlReplaceHelperTest.java @@ -16,26 +16,22 @@ class SqlReplaceHelperTest { @Test void testReplaceAggField() { - String sql = - "SELECT 维度1,sum(播放量) FROM 数据库 " - + "WHERE (歌手名 = '张三') AND 数据日期 = '2023-11-17' GROUP BY 维度1"; + String sql = "SELECT 维度1,sum(播放量) FROM 数据库 " + + "WHERE (歌手名 = '张三') AND 数据日期 = '2023-11-17' GROUP BY 维度1"; Map> fieldMap = new HashMap<>(); fieldMap.put("播放量", Pair.of("收听用户数", AggOperatorEnum.COUNT_DISTINCT.name())); sql = SqlReplaceHelper.replaceAggFields(sql, fieldMap); System.out.println(sql); - Assert.assertEquals( - "SELECT 维度1, count(DISTINCT 收听用户数) FROM 数据库 " - + "WHERE (歌手名 = '张三') AND 数据日期 = '2023-11-17' GROUP BY 维度1", - sql); + Assert.assertEquals("SELECT 维度1, count(DISTINCT 收听用户数) FROM 数据库 " + + "WHERE (歌手名 = '张三') AND 数据日期 = '2023-11-17' GROUP BY 维度1", sql); } @Test void testReplaceValue() { - String replaceSql = - "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌手名 = '杰伦' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" - + " order by 播放量 desc limit 11"; + String replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌手名 = '杰伦' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; Map> filedNameToValueMap = new HashMap<>(); @@ -51,10 +47,9 @@ class SqlReplaceHelperTest { + "ORDER BY 播放量 DESC LIMIT 11", replaceSql); - replaceSql = - "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌手名 = '周杰' and 歌手名 = '林俊' and 歌手名 = '陈' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" - + " order by 播放量 desc limit 11"; + replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌手名 = '周杰' and 歌手名 = '林俊' and 歌手名 = '陈' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; Map> filedNameToValueMap2 = new HashMap<>(); @@ -72,10 +67,9 @@ class SqlReplaceHelperTest { + "歌曲发布时 = '2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql); - replaceSql = - "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌手名 = '周杰' and 歌手名 = '林俊' and 歌手名 = '陈' and 歌曲发布时 = '2023-08-01') and 数据日期 = '2023-08-09' " - + " order by 播放量 desc limit 11"; + replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌手名 = '周杰' and 歌手名 = '林俊' and 歌手名 = '陈' and 歌曲发布时 = '2023-08-01') and 数据日期 = '2023-08-09' " + + " order by 播放量 desc limit 11"; replaceSql = SqlReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false); @@ -85,12 +79,10 @@ class SqlReplaceHelperTest { + "AND 数据日期 = '2023-08-09' ORDER BY 播放量 DESC LIMIT 11", replaceSql); - replaceSql = - "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌手名 = '周杰' and 歌手名 = '林俊' and 歌手名 = '陈' and 歌曲发布时 = '2023-08-01' and 播放量 < (" - + "select min(播放量) from 歌曲库 where 语种 = '英文' " - + ") ) and 数据日期 = '2023-08-09' " - + " order by 播放量 desc limit 11"; + replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌手名 = '周杰' and 歌手名 = '林俊' and 歌手名 = '陈' and 歌曲发布时 = '2023-08-01' and 播放量 < (" + + "select min(播放量) from 歌曲库 where 语种 = '英文' " + ") ) and 数据日期 = '2023-08-09' " + + " order by 播放量 desc limit 11"; replaceSql = SqlReplaceHelper.replaceValue(replaceSql, filedNameToValueMap2, false); @@ -117,10 +109,9 @@ class SqlReplaceHelperTest { @Test void testReplaceFieldNameByValue() { - String replaceSql = - "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" - + " order by 播放量 desc limit 11"; + String replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌曲名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; Map> fieldValueToFieldNames = new HashMap<>(); fieldValueToFieldNames.put("邓紫棋", Collections.singleton("歌手名")); @@ -133,18 +124,15 @@ class SqlReplaceHelperTest { + "ORDER BY 播放量 DESC LIMIT 11", replaceSql); - replaceSql = - "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌曲名 like '%邓紫棋%' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" - + " order by 播放量 desc limit 11"; + replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌曲名 like '%邓紫棋%' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; replaceSql = SqlReplaceHelper.replaceFieldNameByValue(replaceSql, fieldValueToFieldNames); - Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 " - + "AND 歌曲名 LIKE '%邓紫棋%' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = " - + "'2023-08-01' ORDER BY 播放量 DESC LIMIT 11", - replaceSql); + Assert.assertEquals("SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 " + + "AND 歌曲名 LIKE '%邓紫棋%' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = " + + "'2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql); Set fieldNames = new HashSet<>(); fieldNames.add("歌手名"); @@ -152,23 +140,19 @@ class SqlReplaceHelperTest { fieldNames.add("专辑名"); fieldValueToFieldNames.put("林俊杰", fieldNames); - replaceSql = - "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌手名 = '林俊杰' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" - + " order by 播放量 desc limit 11"; + replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌手名 = '林俊杰' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; replaceSql = SqlReplaceHelper.replaceFieldNameByValue(replaceSql, fieldValueToFieldNames); - Assert.assertEquals( - "SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 " - + "AND 歌手名 = '林俊杰' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = " - + "'2023-08-01' ORDER BY 播放量 DESC LIMIT 11", - replaceSql); + Assert.assertEquals("SELECT 歌曲名 FROM 歌曲库 WHERE datediff('day', 发布日期, '2023-08-09') <= 1 " + + "AND 歌手名 = '林俊杰' AND 数据日期 = '2023-08-09' AND 歌曲发布时 = " + + "'2023-08-01' ORDER BY 播放量 DESC LIMIT 11", replaceSql); - replaceSql = - "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌手名 = '林俊杰' and 歌曲发布时 = '2023-08-01') and 数据日期 = '2023-08-09'" - + " order by 播放量 desc limit 11"; + replaceSql = "select 歌曲名 from 歌曲库 where (datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌手名 = '林俊杰' and 歌曲发布时 = '2023-08-01') and 数据日期 = '2023-08-09'" + + " order by 播放量 desc limit 11"; replaceSql = SqlReplaceHelper.replaceFieldNameByValue(replaceSql, fieldValueToFieldNames); @@ -184,9 +168,8 @@ class SqlReplaceHelperTest { Map fieldToBizName1 = new HashMap<>(); fieldToBizName1.put("公司成立时间", "company_established_time"); fieldToBizName1.put("年营业额", "annual_turnover"); - String replaceSql = - "SELECT * FROM 互联网企业 ORDER BY 公司成立时间 DESC LIMIT 3 " - + "UNION SELECT * FROM 互联网企业 ORDER BY 年营业额 DESC LIMIT 5"; + String replaceSql = "SELECT * FROM 互联网企业 ORDER BY 公司成立时间 DESC LIMIT 3 " + + "UNION SELECT * FROM 互联网企业 ORDER BY 年营业额 DESC LIMIT 5"; replaceSql = SqlReplaceHelper.replaceFields(replaceSql, fieldToBizName1); replaceSql = SqlReplaceHelper.replaceTable(replaceSql, "internet"); Assert.assertEquals( @@ -199,10 +182,9 @@ class SqlReplaceHelperTest { void testReplaceFields() { Map fieldToBizName = initParams(); - String replaceSql = - "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" - + " order by 播放量 desc limit 11"; + String replaceSql = "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' and 歌曲发布时 = '2023-08-01'" + + " order by 播放量 desc limit 11"; replaceSql = SqlReplaceHelper.replaceFields(replaceSql, fieldToBizName); replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); @@ -220,10 +202,9 @@ class SqlReplaceHelperTest { "SELECT 品牌名称 FROM 互联网企业 WHERE 品牌成立时间 < '2006-11-04' AND 注册资本 = 50000000", replaceSql); - replaceSql = - "select MONTH(数据日期), sum(访问次数) from 内容库产品 " - + "where datediff('year', 数据日期, '2023-09-03') <= 0.5 " - + "group by MONTH(数据日期) order by sum(访问次数) desc limit 1"; + replaceSql = "select MONTH(数据日期), sum(访问次数) from 内容库产品 " + + "where datediff('year', 数据日期, '2023-09-03') <= 0.5 " + + "group by MONTH(数据日期) order by sum(访问次数) desc limit 1"; replaceSql = SqlReplaceHelper.replaceFields(replaceSql, fieldToBizName); replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); @@ -234,10 +215,9 @@ class SqlReplaceHelperTest { + " GROUP BY MONTH(sys_imp_date) ORDER BY sum(pv) DESC LIMIT 1", replaceSql); - replaceSql = - "select MONTH(数据日期), sum(访问次数) from 内容库产品 " - + "where datediff('year', 数据日期, '2023-09-03') <= 0.5 " - + "group by MONTH(数据日期) HAVING sum(访问次数) > 1000"; + replaceSql = "select MONTH(数据日期), sum(访问次数) from 内容库产品 " + + "where datediff('year', 数据日期, '2023-09-03') <= 0.5 " + + "group by MONTH(数据日期) HAVING sum(访问次数) > 1000"; replaceSql = SqlReplaceHelper.replaceFields(replaceSql, fieldToBizName); replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); @@ -247,38 +227,30 @@ class SqlReplaceHelperTest { + " sys_imp_date <= '2023-09-03') GROUP BY MONTH(sys_imp_date) HAVING sum(pv) > 1000", replaceSql); - replaceSql = - "select YEAR(发行日期), count(歌曲名) from 歌曲库 where YEAR(发行日期) " - + "in (2022, 2023) and 数据日期 = '2023-08-14' group by YEAR(发行日期)"; + replaceSql = "select YEAR(发行日期), count(歌曲名) from 歌曲库 where YEAR(发行日期) " + + "in (2022, 2023) and 数据日期 = '2023-08-14' group by YEAR(发行日期)"; replaceSql = SqlReplaceHelper.replaceFields(replaceSql, fieldToBizName); replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); - Assert.assertEquals( - "SELECT YEAR(发行日期), count(song_name) FROM 歌曲库 " - + "WHERE YEAR(发行日期) IN (2022, 2023) AND sys_imp_date = '2023-08-14' " - + "GROUP BY YEAR(publish_date)", - replaceSql); + Assert.assertEquals("SELECT YEAR(发行日期), count(song_name) FROM 歌曲库 " + + "WHERE YEAR(发行日期) IN (2022, 2023) AND sys_imp_date = '2023-08-14' " + + "GROUP BY YEAR(publish_date)", replaceSql); - replaceSql = - "select YEAR(发行日期), count(歌曲名) from 歌曲库 " - + "where YEAR(发行日期) in (2022, 2023) and 数据日期 = '2023-08-14' " - + "group by 发行日期"; + replaceSql = "select YEAR(发行日期), count(歌曲名) from 歌曲库 " + + "where YEAR(发行日期) in (2022, 2023) and 数据日期 = '2023-08-14' " + "group by 发行日期"; replaceSql = SqlReplaceHelper.replaceFields(replaceSql, fieldToBizName); replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); - Assert.assertEquals( - "SELECT YEAR(发行日期), count(song_name) FROM 歌曲库 " - + "WHERE YEAR(发行日期) IN (2022, 2023) AND sys_imp_date = '2023-08-14'" - + " GROUP BY publish_date", - replaceSql); + Assert.assertEquals("SELECT YEAR(发行日期), count(song_name) FROM 歌曲库 " + + "WHERE YEAR(发行日期) IN (2022, 2023) AND sys_imp_date = '2023-08-14'" + + " GROUP BY publish_date", replaceSql); - replaceSql = - SqlReplaceHelper.replaceFields( - "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-11') <= 1 " - + "and 结算播放量 > 1000000 and datediff('day', 数据日期, '2023-08-11') <= 30", - fieldToBizName); + replaceSql = SqlReplaceHelper.replaceFields( + "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-11') <= 1 " + + "and 结算播放量 > 1000000 and datediff('day', 数据日期, '2023-08-11') <= 30", + fieldToBizName); replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); Assert.assertEquals( @@ -287,11 +259,10 @@ class SqlReplaceHelperTest { + "(sys_imp_date >= '2023-07-12' AND sys_imp_date <= '2023-08-11')", replaceSql); - replaceSql = - SqlReplaceHelper.replaceFields( - "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " - + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", - fieldToBizName); + replaceSql = SqlReplaceHelper.replaceFields( + "select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 " + + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", + fieldToBizName); replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); Assert.assertEquals( @@ -299,11 +270,10 @@ class SqlReplaceHelperTest { + " AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' ORDER BY play_count DESC LIMIT 11", replaceSql); - replaceSql = - SqlReplaceHelper.replaceFields( - "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') = 0 " - + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", - fieldToBizName); + replaceSql = SqlReplaceHelper.replaceFields( + "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') = 0 " + + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", + fieldToBizName); replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); Assert.assertEquals( @@ -311,11 +281,10 @@ class SqlReplaceHelperTest { + " AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' ORDER BY play_count DESC LIMIT 11", replaceSql); - replaceSql = - SqlReplaceHelper.replaceFields( - "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') <= 0.5 " - + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", - fieldToBizName); + replaceSql = SqlReplaceHelper.replaceFields( + "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') <= 0.5 " + + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", + fieldToBizName); replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); Assert.assertEquals( @@ -323,24 +292,19 @@ class SqlReplaceHelperTest { + " AND singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09' ORDER BY play_count DESC LIMIT 11", replaceSql); - replaceSql = - SqlReplaceHelper.replaceFields( - "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') >= 0.5 " - + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", - fieldToBizName); + replaceSql = SqlReplaceHelper.replaceFields( + "select 歌曲名 from 歌曲库 where datediff('year', 发布日期, '2023-08-09') >= 0.5 " + + "and 歌手名 = '邓紫棋' and 数据日期 = '2023-08-09' order by 播放量 desc limit 11", + fieldToBizName); replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); replaceSql = SqlRemoveHelper.removeNumberFilter(replaceSql); - Assert.assertEquals( - "SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-02-09' AND" - + " singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09'" - + " ORDER BY play_count DESC LIMIT 11", - replaceSql); + Assert.assertEquals("SELECT song_name FROM 歌曲库 WHERE publish_date <= '2023-02-09' AND" + + " singer_name = '邓紫棋' AND sys_imp_date = '2023-08-09'" + + " ORDER BY play_count DESC LIMIT 11", replaceSql); - replaceSql = - SqlReplaceHelper.replaceFields( - "select 部门,用户 from 超音数 where 数据日期 = '2023-08-08' and 用户 ='alice'" - + " and 发布日期 ='11' order by 访问次数 desc limit 1", - fieldToBizName); + replaceSql = SqlReplaceHelper + .replaceFields("select 部门,用户 from 超音数 where 数据日期 = '2023-08-08' and 用户 ='alice'" + + " and 发布日期 ='11' order by 访问次数 desc limit 1", fieldToBizName); replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); Assert.assertEquals( @@ -360,25 +324,20 @@ class SqlReplaceHelperTest { fieldToBizName); replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); - Assert.assertEquals( - "SELECT department, sum(pv) FROM 超音数 WHERE sys_imp_date = '2023-08-08'" - + " AND user_id = 'alice' AND publish_date = '11' GROUP BY department LIMIT 1", + Assert.assertEquals("SELECT department, sum(pv) FROM 超音数 WHERE sys_imp_date = '2023-08-08'" + + " AND user_id = 'alice' AND publish_date = '11' GROUP BY department LIMIT 1", replaceSql); - replaceSql = - "select sum(访问次数) from 超音数 where 数据日期 >= '2023-08-06' " - + "and 数据日期 <= '2023-08-06' and 部门 = 'hr'"; + replaceSql = "select sum(访问次数) from 超音数 where 数据日期 >= '2023-08-06' " + + "and 数据日期 <= '2023-08-06' and 部门 = 'hr'"; replaceSql = SqlReplaceHelper.replaceFields(replaceSql, fieldToBizName); replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); - Assert.assertEquals( - "SELECT sum(pv) FROM 超音数 WHERE sys_imp_date >= '2023-08-06' " - + "AND sys_imp_date <= '2023-08-06' AND department = 'hr'", - replaceSql); + Assert.assertEquals("SELECT sum(pv) FROM 超音数 WHERE sys_imp_date >= '2023-08-06' " + + "AND sys_imp_date <= '2023-08-06' AND department = 'hr'", replaceSql); - replaceSql = - "SELECT 歌曲名称, sum(评分) FROM CSpider WHERE(1 < 2) AND 数据日期 = '2023-10-15' " - + "GROUP BY 歌曲名称 HAVING sum(评分) < ( SELECT min(评分) FROM CSpider WHERE 语种 = '英文')"; + replaceSql = "SELECT 歌曲名称, sum(评分) FROM CSpider WHERE(1 < 2) AND 数据日期 = '2023-10-15' " + + "GROUP BY 歌曲名称 HAVING sum(评分) < ( SELECT min(评分) FROM CSpider WHERE 语种 = '英文')"; replaceSql = SqlReplaceHelper.replaceFields(replaceSql, fieldToBizName); replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); @@ -388,10 +347,9 @@ class SqlReplaceHelperTest { + "sum(评分) < (SELECT min(评分) FROM CSpider WHERE user_id = '英文')", replaceSql); - replaceSql = - "SELECT sum(评分)/ (SELECT sum(评分) FROM CSpider WHERE 数据日期 = '2023-10-15')" - + " FROM CSpider WHERE 数据日期 = '2023-10-15' " - + "GROUP BY 歌曲名称 HAVING sum(评分) < ( SELECT min(评分) FROM CSpider WHERE 语种 = '英文')"; + replaceSql = "SELECT sum(评分)/ (SELECT sum(评分) FROM CSpider WHERE 数据日期 = '2023-10-15')" + + " FROM CSpider WHERE 数据日期 = '2023-10-15' " + + "GROUP BY 歌曲名称 HAVING sum(评分) < ( SELECT min(评分) FROM CSpider WHERE 语种 = '英文')"; replaceSql = SqlReplaceHelper.replaceFields(replaceSql, fieldToBizName); replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); @@ -410,77 +368,59 @@ class SqlReplaceHelperTest { replaceSql = SqlReplaceHelper.replaceFields(replaceSql, fieldToBizName); replaceSql = SqlReplaceHelper.replaceFunction(replaceSql); - Assert.assertEquals( - "SELECT TIMESTAMPDIFF(MONTH, song_publis_date, CURDATE()) AS 发布月数 " - + "FROM 歌曲库 WHERE singer_name = '邓紫棋'", - replaceSql); + Assert.assertEquals("SELECT TIMESTAMPDIFF(MONTH, song_publis_date, CURDATE()) AS 发布月数 " + + "FROM 歌曲库 WHERE singer_name = '邓紫棋'", replaceSql); } @Test void testReplaceTable() { - String sql = - "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " - + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; + String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " + + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; String replaceSql = SqlReplaceHelper.replaceTable(sql, "s2"); - Assert.assertEquals( - "SELECT 部门, sum(访问次数) FROM s2 WHERE 数据日期 = '2023-08-08' " - + "AND 用户 = alice AND 发布日期 = '11' GROUP BY 部门 LIMIT 1", - replaceSql); + Assert.assertEquals("SELECT 部门, sum(访问次数) FROM s2 WHERE 数据日期 = '2023-08-08' " + + "AND 用户 = alice AND 发布日期 = '11' GROUP BY 部门 LIMIT 1", replaceSql); - sql = - "select * from 互联网企业 order by 公司成立时间 desc limit 3 union select * from 互联网企业 order by 年营业额 desc limit 5"; + sql = "select * from 互联网企业 order by 公司成立时间 desc limit 3 union select * from 互联网企业 order by 年营业额 desc limit 5"; replaceSql = SqlReplaceHelper.replaceTable(sql, "internet"); - Assert.assertEquals( - "SELECT * FROM internet ORDER BY 公司成立时间 DESC LIMIT 3 " - + "UNION SELECT * FROM internet ORDER BY 年营业额 DESC LIMIT 5", - replaceSql); + Assert.assertEquals("SELECT * FROM internet ORDER BY 公司成立时间 DESC LIMIT 3 " + + "UNION SELECT * FROM internet ORDER BY 年营业额 DESC LIMIT 5", replaceSql); - sql = - "SELECT * FROM CSpider音乐 WHERE (评分 < (SELECT min(评分) " - + "FROM CSpider音乐 WHERE 语种 = '英文')) AND 数据日期 = '2023-10-11'"; + sql = "SELECT * FROM CSpider音乐 WHERE (评分 < (SELECT min(评分) " + + "FROM CSpider音乐 WHERE 语种 = '英文')) AND 数据日期 = '2023-10-11'"; replaceSql = SqlReplaceHelper.replaceTable(sql, "cspider"); - Assert.assertEquals( - "SELECT * FROM cspider WHERE (评分 < (SELECT min(评分) FROM " - + "cspider WHERE 语种 = '英文')) AND 数据日期 = '2023-10-11'", - replaceSql); + Assert.assertEquals("SELECT * FROM cspider WHERE (评分 < (SELECT min(评分) FROM " + + "cspider WHERE 语种 = '英文')) AND 数据日期 = '2023-10-11'", replaceSql); - sql = - "SELECT 歌曲名称, sum(评分) FROM CSpider音乐 WHERE(1 < 2) AND 数据日期 = '2023-10-15' " - + "GROUP BY 歌曲名称 HAVING sum(评分) < ( SELECT min(评分) FROM CSpider音乐 WHERE 语种 = '英文')"; + sql = "SELECT 歌曲名称, sum(评分) FROM CSpider音乐 WHERE(1 < 2) AND 数据日期 = '2023-10-15' " + + "GROUP BY 歌曲名称 HAVING sum(评分) < ( SELECT min(评分) FROM CSpider音乐 WHERE 语种 = '英文')"; replaceSql = SqlReplaceHelper.replaceTable(sql, "cspider"); - Assert.assertEquals( - "SELECT 歌曲名称, sum(评分) FROM cspider WHERE (1 < 2) AND 数据日期 = " - + "'2023-10-15' GROUP BY 歌曲名称 HAVING sum(评分) < (SELECT min(评分) " - + "FROM cspider WHERE 语种 = '英文')", - replaceSql); + Assert.assertEquals("SELECT 歌曲名称, sum(评分) FROM cspider WHERE (1 < 2) AND 数据日期 = " + + "'2023-10-15' GROUP BY 歌曲名称 HAVING sum(评分) < (SELECT min(评分) " + + "FROM cspider WHERE 语种 = '英文')", replaceSql); } @Test void testReplaceFunctionName() { - String sql = - "select 公司名称,平均(注册资本),总部地点 from 互联网企业 where\n" - + "年营业额 >= 28800000000 and 最大(注册资本)>10000 \n" - + " group by 公司名称 having 平均(注册资本)>10000 order by \n" - + "平均(注册资本) desc limit 5"; + String sql = "select 公司名称,平均(注册资本),总部地点 from 互联网企业 where\n" + + "年营业额 >= 28800000000 and 最大(注册资本)>10000 \n" + + " group by 公司名称 having 平均(注册资本)>10000 order by \n" + "平均(注册资本) desc limit 5"; Map map = new HashMap<>(); map.put("平均", "avg"); map.put("最大", "max"); sql = SqlReplaceHelper.replaceFunction(sql, map); System.out.println(sql); - Assert.assertEquals( - "SELECT 公司名称, avg(注册资本), 总部地点 FROM 互联网企业 WHERE 年营业额 >= 28800000000 AND " - + "max(注册资本) > 10000 GROUP BY 公司名称 HAVING avg(注册资本) > 10000 ORDER BY avg(注册资本) DESC LIMIT 5", + Assert.assertEquals("SELECT 公司名称, avg(注册资本), 总部地点 FROM 互联网企业 WHERE 年营业额 >= 28800000000 AND " + + "max(注册资本) > 10000 GROUP BY 公司名称 HAVING avg(注册资本) > 10000 ORDER BY avg(注册资本) DESC LIMIT 5", sql); - sql = - "select MONTH(数据日期) as 月份, avg(访问次数) as 平均访问次数 from 内容库产品 where" - + " datediff('month', 数据日期, '2023-09-02') <= 6 group by MONTH(数据日期)"; + sql = "select MONTH(数据日期) as 月份, avg(访问次数) as 平均访问次数 from 内容库产品 where" + + " datediff('month', 数据日期, '2023-09-02') <= 6 group by MONTH(数据日期)"; Map functionMap = new HashMap<>(); functionMap.put("MONTH".toLowerCase(), "toMonth"); String replaceSql = SqlReplaceHelper.replaceFunction(sql, functionMap); @@ -490,9 +430,8 @@ class SqlReplaceHelperTest { + " datediff('month', 数据日期, '2023-09-02') <= 6 GROUP BY toMonth(数据日期)", replaceSql); - sql = - "select month(数据日期) as 月份, avg(访问次数) as 平均访问次数 from 内容库产品 where" - + " datediff('month', 数据日期, '2023-09-02') <= 6 group by MONTH(数据日期)"; + sql = "select month(数据日期) as 月份, avg(访问次数) as 平均访问次数 from 内容库产品 where" + + " datediff('month', 数据日期, '2023-09-02') <= 6 group by MONTH(数据日期)"; replaceSql = SqlReplaceHelper.replaceFunction(sql, functionMap); Assert.assertEquals( @@ -500,59 +439,48 @@ class SqlReplaceHelperTest { + " datediff('month', 数据日期, '2023-09-02') <= 6 GROUP BY toMonth(数据日期)", replaceSql); - sql = - "select month(数据日期) as 月份, avg(访问次数) as 平均访问次数 from 内容库产品 where" - + " (datediff('month', 数据日期, '2023-09-02') <= 6) and 数据日期 = '2023-10-10' group by MONTH(数据日期)"; + sql = "select month(数据日期) as 月份, avg(访问次数) as 平均访问次数 from 内容库产品 where" + + " (datediff('month', 数据日期, '2023-09-02') <= 6) and 数据日期 = '2023-10-10' group by MONTH(数据日期)"; replaceSql = SqlReplaceHelper.replaceFunction(sql, functionMap); - Assert.assertEquals( - "SELECT toMonth(数据日期) AS 月份, avg(访问次数) AS 平均访问次数 FROM 内容库产品 WHERE" - + " (datediff('month', 数据日期, '2023-09-02') <= 6) AND " - + "数据日期 = '2023-10-10' GROUP BY toMonth(数据日期)", - replaceSql); + Assert.assertEquals("SELECT toMonth(数据日期) AS 月份, avg(访问次数) AS 平均访问次数 FROM 内容库产品 WHERE" + + " (datediff('month', 数据日期, '2023-09-02') <= 6) AND " + + "数据日期 = '2023-10-10' GROUP BY toMonth(数据日期)", replaceSql); } @Test void testReplaceAlias() { - String sql = - "select 部门, sum(访问次数) as 总访问次数 from 超音数 where " - + "datediff('day', 数据日期, '2023-09-05') <= 3 group by 部门 order by 总访问次数 desc limit 10"; + String sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where " + + "datediff('day', 数据日期, '2023-09-05') <= 3 group by 部门 order by 总访问次数 desc limit 10"; String replaceSql = SqlReplaceHelper.replaceAlias(sql); System.out.println(replaceSql); - Assert.assertEquals( - "SELECT 部门, sum(访问次数) FROM 超音数 WHERE " - + "datediff('day', 数据日期, '2023-09-05') <= 3 GROUP BY 部门 ORDER BY sum(访问次数) DESC LIMIT 10", + Assert.assertEquals("SELECT 部门, sum(访问次数) FROM 超音数 WHERE " + + "datediff('day', 数据日期, '2023-09-05') <= 3 GROUP BY 部门 ORDER BY sum(访问次数) DESC LIMIT 10", replaceSql); - sql = - "select 部门, sum(访问次数) as 总访问次数 from 超音数 where " - + "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' " - + "group by 部门 order by 总访问次数 desc limit 10"; + sql = "select 部门, sum(访问次数) as 总访问次数 from 超音数 where " + + "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' " + + "group by 部门 order by 总访问次数 desc limit 10"; replaceSql = SqlReplaceHelper.replaceAlias(sql); System.out.println(replaceSql); - Assert.assertEquals( - "SELECT 部门, sum(访问次数) FROM 超音数 WHERE " - + "(datediff('day', 数据日期, '2023-09-05') <= 3) AND 数据日期 = '2023-10-10' " - + "GROUP BY 部门 ORDER BY sum(访问次数) DESC LIMIT 10", - replaceSql); + Assert.assertEquals("SELECT 部门, sum(访问次数) FROM 超音数 WHERE " + + "(datediff('day', 数据日期, '2023-09-05') <= 3) AND 数据日期 = '2023-10-10' " + + "GROUP BY 部门 ORDER BY sum(访问次数) DESC LIMIT 10", replaceSql); - sql = - "select 部门, sum(访问次数) as 访问次数 from 超音数 where " - + "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' " - + "group by 部门 order by 访问次数 desc limit 10"; + sql = "select 部门, sum(访问次数) as 访问次数 from 超音数 where " + + "(datediff('day', 数据日期, '2023-09-05') <= 3) and 数据日期 = '2023-10-10' " + + "group by 部门 order by 访问次数 desc limit 10"; replaceSql = SqlReplaceHelper.replaceAlias(sql); System.out.println(replaceSql); - Assert.assertEquals( - "SELECT 部门, sum(访问次数) AS 访问次数 FROM 超音数 WHERE (datediff('day', 数据日期, " - + "'2023-09-05') <= 3) AND 数据日期 = '2023-10-10' GROUP BY 部门 ORDER BY 访问次数 DESC LIMIT 10", + Assert.assertEquals("SELECT 部门, sum(访问次数) AS 访问次数 FROM 超音数 WHERE (datediff('day', 数据日期, " + + "'2023-09-05') <= 3) AND 数据日期 = '2023-10-10' GROUP BY 部门 ORDER BY 访问次数 DESC LIMIT 10", replaceSql); } @Test void testReplaceAggAliasOrderItem() { - String sql = - "SELECT SUM(访问次数) AS top10总播放量 FROM (SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数 " - + "GROUP BY 部门 ORDER BY SUM(访问次数) DESC LIMIT 10) AS top10"; + String sql = "SELECT SUM(访问次数) AS top10总播放量 FROM (SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数 " + + "GROUP BY 部门 ORDER BY SUM(访问次数) DESC LIMIT 10) AS top10"; String replaceSql = SqlReplaceHelper.replaceAggAliasOrderItem(sql); Assert.assertEquals( "SELECT SUM(访问次数) AS top10总播放量 FROM (SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数 " diff --git a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlSelectFunctionHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlSelectFunctionHelperTest.java index 07a891afd..0c2719542 100644 --- a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlSelectFunctionHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlSelectFunctionHelperTest.java @@ -10,39 +10,33 @@ class SqlSelectFunctionHelperTest { @Test void testHasAggregateFunction() throws JSQLParserException { - String sql = - "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " - + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; + String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " + + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; boolean hasAggregateFunction = SqlSelectFunctionHelper.hasAggregateFunction(sql); Assert.assertEquals(hasAggregateFunction, true); - sql = - "select 部门,count (访问次数) from 超音数 where 数据日期 = '2023-08-08' " - + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; + sql = "select 部门,count (访问次数) from 超音数 where 数据日期 = '2023-08-08' " + + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; hasAggregateFunction = SqlSelectFunctionHelper.hasAggregateFunction(sql); Assert.assertEquals(hasAggregateFunction, true); - sql = - "SELECT count(1) FROM s2 WHERE sys_imp_date = '2023-08-08' AND user_id = 'alice'" - + " AND publish_date = '11' ORDER BY pv DESC LIMIT 1"; + sql = "SELECT count(1) FROM s2 WHERE sys_imp_date = '2023-08-08' AND user_id = 'alice'" + + " AND publish_date = '11' ORDER BY pv DESC LIMIT 1"; hasAggregateFunction = SqlSelectFunctionHelper.hasAggregateFunction(sql); Assert.assertEquals(hasAggregateFunction, true); - sql = - "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' " - + "AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"; + sql = "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' " + + "AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"; hasAggregateFunction = SqlSelectFunctionHelper.hasAggregateFunction(sql); Assert.assertEquals(hasAggregateFunction, false); - sql = - "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08'" - + " AND user_id = 'alice' AND publish_date = '11'"; + sql = "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08'" + + " AND user_id = 'alice' AND publish_date = '11'"; hasAggregateFunction = SqlSelectFunctionHelper.hasAggregateFunction(sql); Assert.assertEquals(hasAggregateFunction, false); - sql = - "SELECT user_name, pv FROM t_34 WHERE sys_imp_date <= '2023-09-03' " - + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name ORDER BY sum(pv) DESC LIMIT 10"; + sql = "SELECT user_name, pv FROM t_34 WHERE sys_imp_date <= '2023-09-03' " + + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name ORDER BY sum(pv) DESC LIMIT 10"; hasAggregateFunction = SqlSelectFunctionHelper.hasAggregateFunction(sql); Assert.assertEquals(hasAggregateFunction, true); } @@ -50,33 +44,28 @@ class SqlSelectFunctionHelperTest { @Test void testHasFunction() throws JSQLParserException { - String sql = - "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " - + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; + String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " + + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; boolean hasFunction = SqlSelectFunctionHelper.hasFunction(sql, "sum"); Assert.assertEquals(hasFunction, true); - sql = - "select 部门,count (访问次数) from 超音数 where 数据日期 = '2023-08-08' " - + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; + sql = "select 部门,count (访问次数) from 超音数 where 数据日期 = '2023-08-08' " + + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; hasFunction = SqlSelectFunctionHelper.hasFunction(sql, "count"); Assert.assertEquals(hasFunction, true); - sql = - "select 部门,count (*) from 超音数 where 数据日期 = '2023-08-08' " - + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; + sql = "select 部门,count (*) from 超音数 where 数据日期 = '2023-08-08' " + + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; hasFunction = SqlSelectFunctionHelper.hasFunction(sql, "count"); Assert.assertEquals(hasFunction, true); - sql = - "SELECT user_name, pv FROM t_34 WHERE sys_imp_date <= '2023-09-03' " - + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name ORDER BY sum(pv) DESC LIMIT 10"; + sql = "SELECT user_name, pv FROM t_34 WHERE sys_imp_date <= '2023-09-03' " + + "AND sys_imp_date >= '2023-08-04' GROUP BY user_name ORDER BY sum(pv) DESC LIMIT 10"; hasFunction = SqlSelectFunctionHelper.hasFunction(sql, "sum"); Assert.assertEquals(hasFunction, false); - sql = - "select 部门,min (访问次数) from 超音数 where 数据日期 = '2023-08-08' " - + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; + sql = "select 部门,min (访问次数) from 超音数 where 数据日期 = '2023-08-08' " + + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; hasFunction = SqlSelectFunctionHelper.hasFunction(sql, "min"); Assert.assertEquals(hasFunction, true); @@ -84,9 +73,8 @@ class SqlSelectFunctionHelperTest { @Test void testHasAsterisk() { - String sql = - "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " - + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; + String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " + + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; Assert.assertEquals(SqlSelectFunctionHelper.hasAsterisk(sql), false); sql = "select * from 超音数 where 数据日期 = '2023-08-08' " + "and 用户 =alice and 发布日期 ='11'"; diff --git a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelperTest.java index 44e841d6d..e5d68b762 100644 --- a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlSelectHelperTest.java @@ -13,132 +13,109 @@ class SqlSelectHelperTest { @Test void testGetWhereFilterExpression() { - Select selectStatement = - SqlSelectHelper.getSelect( - "select 用户名, 访问次数 from 超音数 where 用户名 in ('alice', 'lucy')"); + Select selectStatement = SqlSelectHelper + .getSelect("select 用户名, 访问次数 from 超音数 where 用户名 in ('alice', 'lucy')"); System.out.println(selectStatement); - List fieldExpression = - SqlSelectHelper.getFilterExpression( - "SELECT department, user_id, field_a FROM s2 WHERE " - + "sys_imp_date = '2023-08-08' AND YEAR(publish_date) = 2023 " - + " AND user_id = 'alice' ORDER BY pv DESC LIMIT 1"); + List fieldExpression = SqlSelectHelper + .getFilterExpression("SELECT department, user_id, field_a FROM s2 WHERE " + + "sys_imp_date = '2023-08-08' AND YEAR(publish_date) = 2023 " + + " AND user_id = 'alice' ORDER BY pv DESC LIMIT 1"); + + System.out.println(fieldExpression); + + fieldExpression = SqlSelectHelper.getFilterExpression( + "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' " + + " AND YEAR(publish_date) = 2023 " + " AND MONTH(publish_date) = 8" + + " AND user_id = 'alice' ORDER BY pv DESC LIMIT 1"); + + System.out.println(fieldExpression); + + fieldExpression = SqlSelectHelper.getFilterExpression( + "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08'" + + " AND YEAR(publish_date) = 2023 " + + " AND MONTH(publish_date) = 8 AND DAY(publish_date) =20 " + + " AND user_id = 'alice' ORDER BY pv DESC LIMIT 1"); + System.out.println(fieldExpression); + + fieldExpression = SqlSelectHelper.getFilterExpression( + "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' " + + " AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"); + + System.out.println(fieldExpression); + + fieldExpression = SqlSelectHelper.getFilterExpression( + "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' " + + "AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"); + + System.out.println(fieldExpression); + + fieldExpression = SqlSelectHelper.getFilterExpression( + "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' " + + "AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"); + + System.out.println(fieldExpression); + + fieldExpression = SqlSelectHelper + .getFilterExpression("SELECT department, user_id, field_a FROM s2 WHERE " + + "user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"); + + System.out.println(fieldExpression); + + fieldExpression = SqlSelectHelper + .getFilterExpression("SELECT department, user_id, field_a FROM s2 WHERE " + + "user_id = 'alice' AND publish_date > 10000 ORDER BY pv DESC LIMIT 1"); + + System.out.println(fieldExpression); + + fieldExpression = SqlSelectHelper + .getFilterExpression("SELECT department, user_id, field_a FROM s2 WHERE " + + "user_id like '%alice%' AND publish_date > 10000 ORDER BY pv DESC LIMIT 1"); + + System.out.println(fieldExpression); + + fieldExpression = SqlSelectHelper.getFilterExpression("SELECT department, pv FROM s2 WHERE " + + "user_id like '%alice%' AND publish_date > 10000 " + + "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1"); + + System.out.println(fieldExpression); + + fieldExpression = SqlSelectHelper.getFilterExpression("SELECT department, pv FROM s2 WHERE " + + "(user_id like '%alice%' AND publish_date > 10000) and sys_imp_date = '2023-08-08' " + + "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1"); + + System.out.println(fieldExpression); + + fieldExpression = SqlSelectHelper.getFilterExpression("SELECT department, pv FROM s2 WHERE " + + "(user_id like '%alice%' AND publish_date > 10000) and song_name in " + + "('七里香','晴天') and sys_imp_date = '2023-08-08' " + + "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1"); + + System.out.println(fieldExpression); + + fieldExpression = SqlSelectHelper.getFilterExpression("SELECT department, pv FROM s2 WHERE " + + "(user_id like '%alice%' AND publish_date > 10000) and song_name in (1,2) " + + "and sys_imp_date = '2023-08-08' " + + "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1"); + + System.out.println(fieldExpression); + + fieldExpression = SqlSelectHelper.getFilterExpression("SELECT department, pv FROM s2 WHERE " + + "(user_id like '%alice%' AND publish_date > 10000) and 1 in (1) " + + "and sys_imp_date = '2023-08-08' " + + "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1"); System.out.println(fieldExpression); fieldExpression = - SqlSelectHelper.getFilterExpression( - "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' " - + " AND YEAR(publish_date) = 2023 " - + " AND MONTH(publish_date) = 8" - + " AND user_id = 'alice' ORDER BY pv DESC LIMIT 1"); + SqlSelectHelper.getFilterExpression("SELECT sum(销量) / (SELECT sum(销量) FROM 营销月模型 " + + "WHERE MONTH(数据日期) = 9) FROM 营销月模型 WHERE 国家中文名 = '肯尼亚' AND MONTH(数据日期) = 9"); System.out.println(fieldExpression); - fieldExpression = - SqlSelectHelper.getFilterExpression( - "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08'" - + " AND YEAR(publish_date) = 2023 " - + " AND MONTH(publish_date) = 8 AND DAY(publish_date) =20 " - + " AND user_id = 'alice' ORDER BY pv DESC LIMIT 1"); - System.out.println(fieldExpression); - - fieldExpression = - SqlSelectHelper.getFilterExpression( - "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' " - + " AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"); - - System.out.println(fieldExpression); - - fieldExpression = - SqlSelectHelper.getFilterExpression( - "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' " - + "AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"); - - System.out.println(fieldExpression); - - fieldExpression = - SqlSelectHelper.getFilterExpression( - "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08' " - + "AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"); - - System.out.println(fieldExpression); - - fieldExpression = - SqlSelectHelper.getFilterExpression( - "SELECT department, user_id, field_a FROM s2 WHERE " - + "user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"); - - System.out.println(fieldExpression); - - fieldExpression = - SqlSelectHelper.getFilterExpression( - "SELECT department, user_id, field_a FROM s2 WHERE " - + "user_id = 'alice' AND publish_date > 10000 ORDER BY pv DESC LIMIT 1"); - - System.out.println(fieldExpression); - - fieldExpression = - SqlSelectHelper.getFilterExpression( - "SELECT department, user_id, field_a FROM s2 WHERE " - + "user_id like '%alice%' AND publish_date > 10000 ORDER BY pv DESC LIMIT 1"); - - System.out.println(fieldExpression); - - fieldExpression = - SqlSelectHelper.getFilterExpression( - "SELECT department, pv FROM s2 WHERE " - + "user_id like '%alice%' AND publish_date > 10000 " - + "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1"); - - System.out.println(fieldExpression); - - fieldExpression = - SqlSelectHelper.getFilterExpression( - "SELECT department, pv FROM s2 WHERE " - + "(user_id like '%alice%' AND publish_date > 10000) and sys_imp_date = '2023-08-08' " - + "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1"); - - System.out.println(fieldExpression); - - fieldExpression = - SqlSelectHelper.getFilterExpression( - "SELECT department, pv FROM s2 WHERE " - + "(user_id like '%alice%' AND publish_date > 10000) and song_name in " - + "('七里香','晴天') and sys_imp_date = '2023-08-08' " - + "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1"); - - System.out.println(fieldExpression); - - fieldExpression = - SqlSelectHelper.getFilterExpression( - "SELECT department, pv FROM s2 WHERE " - + "(user_id like '%alice%' AND publish_date > 10000) and song_name in (1,2) " - + "and sys_imp_date = '2023-08-08' " - + "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1"); - - System.out.println(fieldExpression); - - fieldExpression = - SqlSelectHelper.getFilterExpression( - "SELECT department, pv FROM s2 WHERE " - + "(user_id like '%alice%' AND publish_date > 10000) and 1 in (1) " - + "and sys_imp_date = '2023-08-08' " - + "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1"); - - System.out.println(fieldExpression); - - fieldExpression = - SqlSelectHelper.getFilterExpression( - "SELECT sum(销量) / (SELECT sum(销量) FROM 营销月模型 " - + "WHERE MONTH(数据日期) = 9) FROM 营销月模型 WHERE 国家中文名 = '肯尼亚' AND MONTH(数据日期) = 9"); - - System.out.println(fieldExpression); - - fieldExpression = - SqlSelectHelper.getFilterExpression( - "select 等级, count(*) from 歌手 where 别名 = '港台' or 活跃区域 = '港台' and" - + " datediff('day', 数据日期, '2023-12-24') <= 0 group by 等级"); + fieldExpression = SqlSelectHelper.getFilterExpression( + "select 等级, count(*) from 歌手 where 别名 = '港台' or 活跃区域 = '港台' and" + + " datediff('day', 数据日期, '2023-12-24') <= 0 group by 等级"); System.out.println(fieldExpression); } @@ -146,66 +123,56 @@ class SqlSelectHelperTest { @Test void testGetAllFields() { - List allFields = - SqlSelectHelper.getAllSelectFields( - "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08'" - + " AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"); + List allFields = SqlSelectHelper.getAllSelectFields( + "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date = '2023-08-08'" + + " AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"); Assert.assertEquals(allFields.size(), 6); - allFields = - SqlSelectHelper.getAllSelectFields( - "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date >= '2023-08-08'" - + " AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"); + allFields = SqlSelectHelper.getAllSelectFields( + "SELECT department, user_id, field_a FROM s2 WHERE sys_imp_date >= '2023-08-08'" + + " AND user_id = 'alice' AND publish_date = '11' ORDER BY pv DESC LIMIT 1"); Assert.assertEquals(allFields.size(), 6); - allFields = - SqlSelectHelper.getAllSelectFields( - "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' and 用户 = 'alice'" - + " and 发布日期 ='11' group by 部门 limit 1"); + allFields = SqlSelectHelper.getAllSelectFields( + "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' and 用户 = 'alice'" + + " and 发布日期 ='11' group by 部门 limit 1"); Assert.assertEquals(allFields.size(), 5); - allFields = - SqlSelectHelper.getAllSelectFields( - "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " - + "sys_imp_date >= '2023-08-04' GROUP BY user_name ORDER BY sum(pv) DESC LIMIT 10 "); + allFields = SqlSelectHelper.getAllSelectFields( + "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + + "sys_imp_date >= '2023-08-04' GROUP BY user_name ORDER BY sum(pv) DESC LIMIT 10 "); Assert.assertEquals(allFields.size(), 3); - allFields = - SqlSelectHelper.getAllSelectFields( - "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " - + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"); + allFields = SqlSelectHelper.getAllSelectFields( + "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"); Assert.assertEquals(allFields.size(), 3); - allFields = - SqlSelectHelper.getAllSelectFields( - "SELECT department, user_id, field_a FROM s2 WHERE " - + "(user_id = 'alice' AND publish_date = '11') and sys_imp_date " - + "= '2023-08-08' ORDER BY pv DESC LIMIT 1"); + allFields = SqlSelectHelper + .getAllSelectFields("SELECT department, user_id, field_a FROM s2 WHERE " + + "(user_id = 'alice' AND publish_date = '11') and sys_imp_date " + + "= '2023-08-08' ORDER BY pv DESC LIMIT 1"); Assert.assertEquals(allFields.size(), 6); - allFields = - SqlSelectHelper.getAllSelectFields( - "SELECT * FROM CSpider WHERE (评分 < (SELECT min(评分) FROM CSpider WHERE 语种 = '英文' ))" - + " AND 数据日期 = '2023-10-12'"); + allFields = SqlSelectHelper.getAllSelectFields( + "SELECT * FROM CSpider WHERE (评分 < (SELECT min(评分) FROM CSpider WHERE 语种 = '英文' ))" + + " AND 数据日期 = '2023-10-12'"); Assert.assertEquals(allFields.size(), 3); - allFields = - SqlSelectHelper.getAllSelectFields( - "SELECT sum(销量) / (SELECT sum(销量) FROM 营销 " - + "WHERE MONTH(数据日期) = 9) FROM 营销 WHERE 国家中文名 = '中国' AND MONTH(数据日期) = 9"); + allFields = SqlSelectHelper.getAllSelectFields("SELECT sum(销量) / (SELECT sum(销量) FROM 营销 " + + "WHERE MONTH(数据日期) = 9) FROM 营销 WHERE 国家中文名 = '中国' AND MONTH(数据日期) = 9"); Assert.assertEquals(allFields.size(), 3); - allFields = - SqlSelectHelper.getAllSelectFields( - "SELECT 用户, 页面 FROM 超音数用户部门 GROUP BY 用户, 页面 ORDER BY count(*) DESC"); + allFields = SqlSelectHelper.getAllSelectFields( + "SELECT 用户, 页面 FROM 超音数用户部门 GROUP BY 用户, 页面 ORDER BY count(*) DESC"); Assert.assertEquals(allFields.size(), 2); } @@ -213,9 +180,8 @@ class SqlSelectHelperTest { @Test void testGetSelectFields() { - String sql = - "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " - + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; + String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08' " + + "and 用户 =alice and 发布日期 ='11' group by 部门 limit 1"; List selectFields = SqlSelectHelper.getSelectFields(sql); Assert.assertEquals(selectFields.contains("访问次数"), true); @@ -225,28 +191,25 @@ class SqlSelectHelperTest { @Test void testGetWhereFields() { - String sql = - "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08'" - + " and 用户 = 'alice' and 发布日期 ='11' group by 部门 limit 1"; + String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08'" + + " and 用户 = 'alice' and 发布日期 ='11' group by 部门 limit 1"; List selectFields = SqlSelectHelper.getWhereFields(sql); Assert.assertEquals(selectFields.contains("发布日期"), true); Assert.assertEquals(selectFields.contains("数据日期"), true); Assert.assertEquals(selectFields.contains("用户"), true); - sql = - "select 部门,用户 from 超音数 where 数据日期 = '2023-08-08'" - + " and 用户 = 'alice' and 发布日期 ='11' order by 访问次数 limit 1"; + sql = "select 部门,用户 from 超音数 where 数据日期 = '2023-08-08'" + + " and 用户 = 'alice' and 发布日期 ='11' order by 访问次数 limit 1"; selectFields = SqlSelectHelper.getWhereFields(sql); Assert.assertEquals(selectFields.contains("发布日期"), true); Assert.assertEquals(selectFields.contains("数据日期"), true); Assert.assertEquals(selectFields.contains("用户"), true); - sql = - "select 部门,用户 from 超音数 where" - + " (用户 = 'alice' and 发布日期 ='11') and 数据日期 = '2023-08-08' " - + "order by 访问次数 limit 1"; + sql = "select 部门,用户 from 超音数 where" + + " (用户 = 'alice' and 发布日期 ='11') and 数据日期 = '2023-08-08' " + + "order by 访问次数 limit 1"; selectFields = SqlSelectHelper.getWhereFields(sql); Assert.assertEquals(selectFields.contains("发布日期"), true); @@ -257,16 +220,14 @@ class SqlSelectHelperTest { @Test void testGetOrderByFields() { - String sql = - "select 部门,用户 from 超音数 where 数据日期 = '2023-08-08'" - + " and 用户 = 'alice' and 发布日期 ='11' order by 访问次数 limit 1"; + String sql = "select 部门,用户 from 超音数 where 数据日期 = '2023-08-08'" + + " and 用户 = 'alice' and 发布日期 ='11' order by 访问次数 limit 1"; List selectFields = SqlSelectHelper.getOrderByFields(sql); Assert.assertEquals(selectFields.contains("访问次数"), true); - sql = - "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " - + "sys_imp_date >= '2023-08-04' GROUP BY user_name ORDER BY sum(pv) DESC LIMIT 10 "; + sql = "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + + "sys_imp_date >= '2023-08-04' GROUP BY user_name ORDER BY sum(pv) DESC LIMIT 10 "; selectFields = SqlSelectHelper.getOrderByFields(sql); Assert.assertEquals(selectFields.contains("pv"), true); @@ -275,9 +236,8 @@ class SqlSelectHelperTest { @Test void testGetGroupByFields() { - String sql = - "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08'" - + " and 用户 = 'alice' and 发布日期 ='11' group by 部门 limit 1"; + String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08'" + + " and 用户 = 'alice' and 发布日期 ='11' group by 部门 limit 1"; List selectFields = SqlSelectHelper.getGroupByFields(sql); Assert.assertEquals(selectFields.contains("部门"), true); @@ -286,9 +246,8 @@ class SqlSelectHelperTest { @Test void testGetHavingExpression() { - String sql = - "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " - + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; + String sql = "SELECT user_name FROM 超音数 WHERE sys_imp_date <= '2023-09-03' AND " + + "sys_imp_date >= '2023-08-04' GROUP BY user_name HAVING sum(pv) > 1000"; List leftExpressionList = SqlSelectHelper.getHavingExpression(sql); Assert.assertEquals(leftExpressionList.get(0).toString(), "sum(pv)"); @@ -297,9 +256,8 @@ class SqlSelectHelperTest { @Test void testGetAggregateFields() { - String sql = - "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08'" - + " and 用户 = 'alice' and 发布日期 ='11' group by 部门 limit 1"; + String sql = "select 部门,sum (访问次数) from 超音数 where 数据日期 = '2023-08-08'" + + " and 用户 = 'alice' and 发布日期 ='11' group by 部门 limit 1"; List selectFields = SqlSelectHelper.getAggregateFields(sql); Assert.assertEquals(selectFields.contains("访问次数"), true); } @@ -307,9 +265,8 @@ class SqlSelectHelperTest { @Test void testGetTableName() { - String sql = - "select 部门,sum (访问次数) from `超音数` where 数据日期 = '2023-08-08'" - + " and 用户 = 'alice' and 发布日期 ='11' group by 部门 limit 1"; + String sql = "select 部门,sum (访问次数) from `超音数` where 数据日期 = '2023-08-08'" + + " and 用户 = 'alice' and 发布日期 ='11' group by 部门 limit 1"; String tableName = SqlSelectHelper.getTableName(sql); Assert.assertEquals(tableName, "超音数"); } @@ -317,9 +274,8 @@ class SqlSelectHelperTest { @Test void testGetPureSelectFields() { - String sql = - "select TIMESTAMPDIFF(MONTH, 发布日期, '2018-06-01') from `超音数` " - + "where 数据日期 = '2023-08-08' and 用户 = 'alice'"; + String sql = "select TIMESTAMPDIFF(MONTH, 发布日期, '2018-06-01') from `超音数` " + + "where 数据日期 = '2023-08-08' and 用户 = 'alice'"; List selectFields = SqlSelectHelper.gePureSelectFields(sql); Assert.assertEquals(selectFields.size(), 0); @@ -327,9 +283,8 @@ class SqlSelectHelperTest { selectFields = SqlSelectHelper.gePureSelectFields(sql); Assert.assertEquals(selectFields.size(), 2); - sql = - "select 发布日期,数据日期,TIMESTAMPDIFF(MONTH, 发布日期, '2018-06-01') from `超音数` where " - + "数据日期 = '2023-08-08' and 用户 = 'alice'"; + sql = "select 发布日期,数据日期,TIMESTAMPDIFF(MONTH, 发布日期, '2018-06-01') from `超音数` where " + + "数据日期 = '2023-08-08' and 用户 = 'alice'"; selectFields = SqlSelectHelper.gePureSelectFields(sql); Assert.assertEquals(selectFields.size(), 2); } diff --git a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelperTest.java index b24eb37b3..e455078bb 100644 --- a/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/jsqlparser/SqlValidHelperTest.java @@ -15,19 +15,15 @@ class SqlValidHelperTest { sql2 = "SELECT d,c,b,a FROM table1 WHERE column2 = 2 AND column1 = 1 order by a"; Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true); - sql1 = - "SELECT a,sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2 group by a order by a"; + sql1 = "SELECT a,sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2 group by a order by a"; - sql2 = - "SELECT sum(d),sum(c),sum(b),a FROM table1 WHERE column2 = 2 AND column1 = 1 group by a order by a"; + sql2 = "SELECT sum(d),sum(c),sum(b),a FROM table1 WHERE column2 = 2 AND column1 = 1 group by a order by a"; Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true); - sql1 = - "SELECT a,sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2 group by a order by a"; + sql1 = "SELECT a,sum(b),sum(c),sum(d) FROM table1 WHERE column1 = 1 AND column2 = 2 group by a order by a"; - sql2 = - "SELECT sum(d),sum(c),sum(b),a FROM table1 WHERE column2 = 2 AND column1 = 1 group by a order by a"; + sql2 = "SELECT sum(d),sum(c),sum(b),a FROM table1 WHERE column2 = 2 AND column1 = 1 group by a order by a"; Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true); @@ -35,35 +31,13 @@ class SqlValidHelperTest { sql2 = "SELECT d,c,b,f FROM table1 WHERE column2 = 2 AND column1 = 1 order by a"; Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), false); - sql1 = - "SELECT\n" - + "页面,\n" - + "SUM(访问次数)\n" - + "FROM\n" - + "超音数\n" - + "WHERE\n" - + "数据日期 >= '2023-10-26'\n" - + "AND 数据日期 <= '2023-11-09'\n" - + "AND department = \"HR\"\n" - + "GROUP BY\n" - + "页面\n" - + "LIMIT\n" - + "365"; + sql1 = "SELECT\n" + "页面,\n" + "SUM(访问次数)\n" + "FROM\n" + "超音数\n" + "WHERE\n" + + "数据日期 >= '2023-10-26'\n" + "AND 数据日期 <= '2023-11-09'\n" + + "AND department = \"HR\"\n" + "GROUP BY\n" + "页面\n" + "LIMIT\n" + "365"; - sql2 = - "SELECT\n" - + "页面,\n" - + "SUM(访问次数)\n" - + "FROM\n" - + "超音数\n" - + "WHERE\n" - + "department = \"HR\"\n" - + "AND 数据日期 >= '2023-10-26'\n" - + "AND 数据日期 <= '2023-11-09'\n" - + "GROUP BY\n" - + "页面\n" - + "LIMIT\n" - + "365"; + sql2 = "SELECT\n" + "页面,\n" + "SUM(访问次数)\n" + "FROM\n" + "超音数\n" + "WHERE\n" + + "department = \"HR\"\n" + "AND 数据日期 >= '2023-10-26'\n" + + "AND 数据日期 <= '2023-11-09'\n" + "GROUP BY\n" + "页面\n" + "LIMIT\n" + "365"; Assert.assertEquals(SqlValidHelper.equals(sql1, sql2), true); } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java index 8c79066c6..76b9c9f87 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/DataSetSchema.java @@ -65,10 +65,8 @@ public class DataSetSchema { List allElements = new ArrayList<>(); allElements.addAll(getDimensions()); allElements.addAll(getMetrics()); - return allElements.stream() - .collect( - Collectors.toMap( - SchemaElement::getBizName, SchemaElement::getName, (k1, k2) -> k1)); + return allElements.stream().collect(Collectors.toMap(SchemaElement::getBizName, + SchemaElement::getName, (k1, k2) -> k1)); } public TimeDefaultConfig getTagTypeTimeDefaultConfig() { @@ -104,16 +102,13 @@ public class DataSetSchema { || Objects.isNull(detailTypeDefaultConfig.getDefaultDisplayInfo())) { return new ArrayList<>(); } - if (CollectionUtils.isNotEmpty( - detailTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds())) { + if (CollectionUtils + .isNotEmpty(detailTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds())) { return detailTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds().stream() - .map( - id -> { - SchemaElement metric = getElement(SchemaElementType.METRIC, id); - return metric; - }) - .filter(Objects::nonNull) - .collect(Collectors.toList()); + .map(id -> { + SchemaElement metric = getElement(SchemaElementType.METRIC, id); + return metric; + }).filter(Objects::nonNull).collect(Collectors.toList()); } return new ArrayList<>(); } @@ -124,11 +119,10 @@ public class DataSetSchema { || Objects.isNull(detailTypeDefaultConfig.getDefaultDisplayInfo())) { return new ArrayList<>(); } - if (CollectionUtils.isNotEmpty( - detailTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds())) { + if (CollectionUtils + .isNotEmpty(detailTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds())) { return detailTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds().stream() - .map(id -> getElement(SchemaElementType.DIMENSION, id)) - .filter(Objects::nonNull) + .map(id -> getElement(SchemaElementType.DIMENSION, id)).filter(Objects::nonNull) .collect(Collectors.toList()); } return new ArrayList<>(); diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/Dim.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/Dim.java index 16d1a6e8d..82f9612c1 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/Dim.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/Dim.java @@ -43,14 +43,8 @@ public class Dim { this.isTag = isTag; } - public Dim( - String name, - String type, - String expr, - String dateFormat, - DimensionTimeTypeParams typeParams, - Integer isCreateDimension, - String bizName) { + public Dim(String name, String type, String expr, String dateFormat, + DimensionTimeTypeParams typeParams, Integer isCreateDimension, String bizName) { this.name = name; this.type = type; this.expr = expr; @@ -61,14 +55,8 @@ public class Dim { } public static Dim getDefault() { - return new Dim( - "日期", - "time", - "2023-05-28", - Constants.DAY_FORMAT, - new DimensionTimeTypeParams("true", "day"), - 0, - "imp_date"); + return new Dim("日期", "time", "2023-05-28", Constants.DAY_FORMAT, + new DimensionTimeTypeParams("true", "day"), 0, "imp_date"); } public String getFieldName() { diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/ItemDateFilter.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/ItemDateFilter.java index 9cf0aa05a..86335540b 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/ItemDateFilter.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/ItemDateFilter.java @@ -13,5 +13,6 @@ import java.util.List; public class ItemDateFilter { private List itemIds; - @NonNull private String type; + @NonNull + private String type; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/MetaFilter.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/MetaFilter.java index c4e369468..22dce4506 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/MetaFilter.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/MetaFilter.java @@ -51,34 +51,18 @@ public class MetaFilter { return false; } MetaFilter that = (MetaFilter) o; - return Objects.equal(id, that.id) - && Objects.equal(name, that.name) - && Objects.equal(bizName, that.bizName) - && Objects.equal(createdBy, that.createdBy) - && Objects.equal(modelIds, that.modelIds) - && Objects.equal(domainId, that.domainId) + return Objects.equal(id, that.id) && Objects.equal(name, that.name) + && Objects.equal(bizName, that.bizName) && Objects.equal(createdBy, that.createdBy) + && Objects.equal(modelIds, that.modelIds) && Objects.equal(domainId, that.domainId) && Objects.equal(dataSetId, that.dataSetId) && Objects.equal(sensitiveLevel, that.sensitiveLevel) - && Objects.equal(status, that.status) - && Objects.equal(key, that.key) - && Objects.equal(ids, that.ids) - && Objects.equal(fieldsDepend, that.fieldsDepend); + && Objects.equal(status, that.status) && Objects.equal(key, that.key) + && Objects.equal(ids, that.ids) && Objects.equal(fieldsDepend, that.fieldsDepend); } @Override public int hashCode() { - return Objects.hashCode( - id, - name, - bizName, - createdBy, - modelIds, - domainId, - dataSetId, - sensitiveLevel, - status, - key, - ids, - fieldsDepend); + return Objects.hashCode(id, name, bizName, createdBy, modelIds, domainId, dataSetId, + sensitiveLevel, status, key, ids, fieldsDepend); } } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/ModelDetail.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/ModelDetail.java index aa46a1b54..051c97fbe 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/ModelDetail.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/ModelDetail.java @@ -55,30 +55,19 @@ public class ModelDetail { List fieldList = Lists.newArrayList(); // Compatible with older versions if (!CollectionUtils.isEmpty(identifiers)) { - fieldList.addAll( - identifiers.stream() - .map( - identify -> - Field.builder() - .fieldName(identify.getFieldName()) - .build()) - .collect(Collectors.toSet())); + fieldList.addAll(identifiers.stream() + .map(identify -> Field.builder().fieldName(identify.getFieldName()).build()) + .collect(Collectors.toSet())); } if (!CollectionUtils.isEmpty(dimensions)) { - fieldList.addAll( - dimensions.stream() - .map(dim -> Field.builder().fieldName(dim.getFieldName()).build()) - .collect(Collectors.toSet())); + fieldList.addAll(dimensions.stream() + .map(dim -> Field.builder().fieldName(dim.getFieldName()).build()) + .collect(Collectors.toSet())); } if (!CollectionUtils.isEmpty(measures)) { - fieldList.addAll( - measures.stream() - .map( - measure -> - Field.builder() - .fieldName(measure.getFieldName()) - .build()) - .collect(Collectors.toSet())); + fieldList.addAll(measures.stream() + .map(measure -> Field.builder().fieldName(measure.getFieldName()).build()) + .collect(Collectors.toSet())); } return fieldList; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/QueryDataType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/QueryDataType.java index 7f92abb19..d8fe8d68a 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/QueryDataType.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/QueryDataType.java @@ -1,8 +1,5 @@ package com.tencent.supersonic.headless.api.pojo; public enum QueryDataType { - METRIC, - DIMENSION, - TAG, - ALL + METRIC, DIMENSION, TAG, ALL } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/RuleInfo.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/RuleInfo.java index 6a8b48a50..f2a88fe9c 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/RuleInfo.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/RuleInfo.java @@ -12,8 +12,6 @@ public class RuleInfo { public enum Mode { /** BEFORE, some days ago RECENT, the last few days EXIST, there was some information */ - BEFORE, - RECENT, - EXIST + BEFORE, RECENT, EXIST } } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java index 7912be900..729b37a47 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElement.java @@ -40,7 +40,8 @@ public class SchemaElement implements Serializable { private int isTag; private String description; private boolean descriptionMapped; - @Builder.Default private Map extInfo = new HashMap<>(); + @Builder.Default + private Map extInfo = new HashMap<>(); private DimensionTimeTypeParams typeParams; @Override @@ -53,8 +54,7 @@ public class SchemaElement implements Serializable { } SchemaElement schemaElement = (SchemaElement) o; return Objects.equal(dataSetId, schemaElement.dataSetId) - && Objects.equal(id, schemaElement.id) - && Objects.equal(name, schemaElement.name) + && Objects.equal(id, schemaElement.id) && Objects.equal(name, schemaElement.name) && Objects.equal(bizName, schemaElement.bizName) && Objects.equal(type, schemaElement.type); } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElementType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElementType.java index 1c08c8949..1c005d39c 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElementType.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaElementType.java @@ -1,13 +1,5 @@ package com.tencent.supersonic.headless.api.pojo; public enum SchemaElementType { - DATASET, - METRIC, - DIMENSION, - VALUE, - ENTITY, - ID, - DATE, - TAG, - TERM + DATASET, METRIC, DIMENSION, VALUE, ENTITY, ID, DATE, TAG, TERM } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaItem.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaItem.java index 46bb86b96..bf954b8cf 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaItem.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SchemaItem.java @@ -40,10 +40,8 @@ public class SchemaItem extends RecordInfo { return false; } SchemaItem that = (SchemaItem) o; - return Objects.equal(id, that.id) - && Objects.equal(name, that.name) - && Objects.equal(bizName, that.bizName) - && typeEnum == that.typeEnum; + return Objects.equal(id, that.id) && Objects.equal(name, that.name) + && Objects.equal(bizName, that.bizName) && typeEnum == that.typeEnum; } @Override diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticSchema.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticSchema.java index f5a2a3d1a..00f7aef8d 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticSchema.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/SemanticSchema.java @@ -59,12 +59,8 @@ public class SemanticSchema implements Serializable { } public Map getDataSetIdToName() { - return dataSetSchemaList.stream() - .collect( - Collectors.toMap( - a -> a.getDataSet().getId(), - a -> a.getDataSet().getName(), - (k1, k2) -> k1)); + return dataSetSchemaList.stream().collect(Collectors.toMap(a -> a.getDataSet().getId(), + a -> a.getDataSet().getName(), (k1, k2) -> k1)); } public List getDimensionValues() { @@ -124,16 +120,15 @@ public class SemanticSchema implements Serializable { return terms; } - private List getElementsByDataSetId( - Long dataSetId, List elements) { + private List getElementsByDataSetId(Long dataSetId, + List elements) { return elements.stream() .filter(schemaElement -> dataSetId.equals(schemaElement.getDataSetId())) .collect(Collectors.toList()); } private Optional getElementsById(Long id, List elements) { - return elements.stream() - .filter(schemaElement -> id.equals(schemaElement.getId())) + return elements.stream().filter(schemaElement -> id.equals(schemaElement.getId())) .findFirst(); } @@ -143,13 +138,9 @@ public class SemanticSchema implements Serializable { } public QueryConfig getQueryConfig(Long dataSetId) { - DataSetSchema first = - dataSetSchemaList.stream() - .filter( - dataSetSchema -> - dataSetId.equals(dataSetSchema.getDataSet().getDataSetId())) - .findFirst() - .orElse(null); + DataSetSchema first = dataSetSchemaList.stream().filter( + dataSetSchema -> dataSetId.equals(dataSetSchema.getDataSet().getDataSetId())) + .findFirst().orElse(null); if (Objects.nonNull(first)) { return first.getQueryConfig(); } @@ -166,10 +157,8 @@ public class SemanticSchema implements Serializable { if (CollectionUtils.isEmpty(dataSetSchemaList)) { return new HashMap<>(); } - return dataSetSchemaList.stream() - .collect( - Collectors.toMap( - dataSetSchema -> dataSetSchema.getDataSet().getDataSetId(), - dataSetSchema -> dataSetSchema)); + return dataSetSchemaList.stream().collect( + Collectors.toMap(dataSetSchema -> dataSetSchema.getDataSet().getDataSetId(), + dataSetSchema -> dataSetSchema)); } } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/AggOption.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/AggOption.java index 72afcbb66..132fad1d4 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/AggOption.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/AggOption.java @@ -5,10 +5,7 @@ package com.tencent.supersonic.headless.api.pojo.enums; * Aggregation DEFAULT: will use the aggregation method define in the model */ public enum AggOption { - NATIVE, - AGGREGATION, - OUTER, - DEFAULT; + NATIVE, AGGREGATION, OUTER, DEFAULT; public static AggOption getAggregation(boolean isNativeQuery) { return isNativeQuery ? NATIVE : AGGREGATION; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/AppStatus.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/AppStatus.java index 4bf590529..3bf8ce0ed 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/AppStatus.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/AppStatus.java @@ -1,11 +1,7 @@ package com.tencent.supersonic.headless.api.pojo.enums; public enum AppStatus { - INIT(0), - ONLINE(1), - OFFLINE(2), - DELETED(3), - UNKNOWN(4); + INIT(0), ONLINE(1), OFFLINE(2), DELETED(3), UNKNOWN(4); private Integer code; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/ChatWorkflowState.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/ChatWorkflowState.java index c30969613..429a3bcdb 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/ChatWorkflowState.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/ChatWorkflowState.java @@ -1,10 +1,5 @@ package com.tencent.supersonic.headless.api.pojo.enums; public enum ChatWorkflowState { - MAPPING, - PARSING, - CORRECTING, - TRANSLATING, - PROCESSING, - FINISHED + MAPPING, PARSING, CORRECTING, TRANSLATING, PROCESSING, FINISHED } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/CostType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/CostType.java index 95b0952e6..dbada2bb9 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/CostType.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/CostType.java @@ -1,10 +1,7 @@ package com.tencent.supersonic.headless.api.pojo.enums; public enum CostType { - MAPPER(1, "mapper"), - PARSER(2, "parser"), - QUERY(3, "query"), - PROCESSOR(4, "processor"); + MAPPER(1, "mapper"), PARSER(2, "parser"), QUERY(3, "query"), PROCESSOR(4, "processor"); private Integer type; private String name; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/DataType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/DataType.java index 51779fc5a..38a3b0515 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/DataType.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/DataType.java @@ -12,57 +12,27 @@ public enum DataType { ORACLE("oracle", "oracle", "oracle.jdbc.driver.OracleDriver", "\"", "\"", "\"", "\""), - SQLSERVER( - "sqlserver", - "sqlserver", - "com.microsoft.sqlserver.jdbc.SQLServerDriver", - "\"", - "\"", - "\"", - "\""), + SQLSERVER("sqlserver", "sqlserver", "com.microsoft.sqlserver.jdbc.SQLServerDriver", "\"", "\"", + "\"", "\""), H2("h2", "h2", "org.h2.Driver", "`", "`", "\"", "\""), - PHOENIX( - "phoenix", - "hbase phoenix", - "org.apache.phoenix.jdbc.PhoenixDriver", - "", - "", - "\"", + PHOENIX("phoenix", "hbase phoenix", "org.apache.phoenix.jdbc.PhoenixDriver", "", "", "\"", "\""), MONGODB("mongo", "mongodb", "mongodb.jdbc.MongoDriver", "`", "`", "\"", "\""), - ELASTICSEARCH( - "elasticsearch", - "elasticsearch", - "com.amazon.opendistroforelasticsearch.jdbc.Driver", - "", - "", - "'", - "'"), + ELASTICSEARCH("elasticsearch", "elasticsearch", + "com.amazon.opendistroforelasticsearch.jdbc.Driver", "", "", "'", "'"), PRESTO("presto", "presto", "com.facebook.presto.jdbc.PrestoDriver", "\"", "\"", "\"", "\""), MOONBOX("moonbox", "moonbox", "moonbox.jdbc.MbDriver", "`", "`", "`", "`"), - CASSANDRA( - "cassandra", - "cassandra", - "com.github.adejanovski.cassandra.jdbc.CassandraDriver", - "", - "", - "'", - "'"), + CASSANDRA("cassandra", "cassandra", "com.github.adejanovski.cassandra.jdbc.CassandraDriver", "", + "", "'", "'"), - CLICKHOUSE( - "clickhouse", - "clickhouse", - "ru.yandex.clickhouse.ClickHouseDriver", - "", - "", - "\"", + CLICKHOUSE("clickhouse", "clickhouse", "ru.yandex.clickhouse.ClickHouseDriver", "", "", "\"", "\""), KYLIN("kylin", "kylin", "org.apache.kylin.jdbc.Driver", "\"", "\"", "\"", "\""), @@ -85,14 +55,8 @@ public enum DataType { private String aliasPrefix; private String aliasSuffix; - DataType( - String feature, - String desc, - String driver, - String keywordPrefix, - String keywordSuffix, - String aliasPrefix, - String aliasSuffix) { + DataType(String feature, String desc, String driver, String keywordPrefix, String keywordSuffix, + String aliasPrefix, String aliasSuffix) { this.feature = feature; this.desc = desc; this.driver = driver; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/DimensionType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/DimensionType.java index ba320e389..6362693a6 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/DimensionType.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/DimensionType.java @@ -1,10 +1,7 @@ package com.tencent.supersonic.headless.api.pojo.enums; public enum DimensionType { - categorical, - time, - partition_time, - identify; + categorical, time, partition_time, identify; public static boolean isTimeDimension(String type) { try { diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/IdentifyType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/IdentifyType.java index dbc65533b..d0492030a 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/IdentifyType.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/IdentifyType.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.headless.api.pojo.enums; public enum IdentifyType { - primary, - foreign, + primary, foreign, } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/MapModeEnum.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/MapModeEnum.java index 33e0aa156..f8f07eae8 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/MapModeEnum.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/MapModeEnum.java @@ -1,9 +1,8 @@ package com.tencent.supersonic.headless.api.pojo.enums; public enum MapModeEnum { - STRICT(0), - MODERATE(2), - LOOSE(4); + STRICT(0), MODERATE(2), LOOSE(4); + public int threshold; MapModeEnum(Integer threshold) { diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/MetricDefineType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/MetricDefineType.java index 611447d23..0edda32b0 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/MetricDefineType.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/MetricDefineType.java @@ -1,7 +1,5 @@ package com.tencent.supersonic.headless.api.pojo.enums; public enum MetricDefineType { - FIELD, - MEASURE, - METRIC + FIELD, MEASURE, METRIC } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/MetricType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/MetricType.java index cc6e6c564..c5cc84035 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/MetricType.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/MetricType.java @@ -7,8 +7,7 @@ import java.util.List; import java.util.Objects; public enum MetricType { - ATOMIC, - DERIVED; + ATOMIC, DERIVED; public static MetricType of(String src) { for (MetricType metricType : MetricType.values()) { @@ -24,8 +23,8 @@ public enum MetricType { return Objects.nonNull(metricType) && metricType.equals(DERIVED); } - public static Boolean isDerived( - MetricDefineType metricDefineType, MetricDefineByMeasureParams typeParams) { + public static Boolean isDerived(MetricDefineType metricDefineType, + MetricDefineByMeasureParams typeParams) { if (MetricDefineType.METRIC.equals(metricDefineType)) { return true; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/ModelDefineType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/ModelDefineType.java index d3a24d191..27f1ecec8 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/ModelDefineType.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/ModelDefineType.java @@ -5,8 +5,7 @@ package com.tencent.supersonic.headless.api.pojo.enums; * dbName.tableName */ public enum ModelDefineType { - SQL_QUERY("sql_query"), - TABLE_QUERY("table_query"); + SQL_QUERY("sql_query"), TABLE_QUERY("table_query"); private String name; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/ModelSourceType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/ModelSourceType.java index e85928e09..e5746d909 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/ModelSourceType.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/ModelSourceType.java @@ -7,9 +7,7 @@ import java.util.Objects; * ZIPPER: table with slowly changing dimension */ public enum ModelSourceType { - FULL, - PARTITION, - ZIPPER; + FULL, PARTITION, ZIPPER; public static ModelSourceType of(String src) { for (ModelSourceType modelSourceTypeEnum : ModelSourceType.values()) { diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/QueryRuleType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/QueryRuleType.java index 3bde82b16..27b1971ae 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/QueryRuleType.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/QueryRuleType.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.headless.api.pojo.enums; public enum QueryRuleType { - ADD_DATE, - ADD_SELECT + ADD_DATE, ADD_SELECT } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/SchemaType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/SchemaType.java index 1f88057ea..4971eca3a 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/SchemaType.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/SchemaType.java @@ -1,6 +1,5 @@ package com.tencent.supersonic.headless.api.pojo.enums; public enum SchemaType { - DATASET, - MODEL + DATASET, MODEL } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/SemanticType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/SemanticType.java index b49d147a9..bcd72bf13 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/SemanticType.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/SemanticType.java @@ -1,8 +1,5 @@ package com.tencent.supersonic.headless.api.pojo.enums; public enum SemanticType { - CATEGORY, - ID, - DATE, - NUMBER + CATEGORY, ID, DATE, NUMBER } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/TagDefineType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/TagDefineType.java index 4257c9b36..af3655837 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/TagDefineType.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/TagDefineType.java @@ -1,8 +1,5 @@ package com.tencent.supersonic.headless.api.pojo.enums; public enum TagDefineType { - FIELD, - DIMENSION, - METRIC, - TAG + FIELD, DIMENSION, METRIC, TAG } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/TagType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/TagType.java index c474056c2..26010f66e 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/TagType.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/TagType.java @@ -3,8 +3,7 @@ package com.tencent.supersonic.headless.api.pojo.enums; import java.util.Objects; public enum TagType { - ATOMIC, - DERIVED; + ATOMIC, DERIVED; public static TagType of(String src) { for (TagType tagType : TagType.values()) { diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/VariableValueType.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/VariableValueType.java index 59fdd33fc..f79503de7 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/VariableValueType.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/enums/VariableValueType.java @@ -1,7 +1,5 @@ package com.tencent.supersonic.headless.api.pojo.enums; public enum VariableValueType { - STRING, - NUMBER, - EXPR + STRING, NUMBER, EXPR } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/DateInfoReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/DateInfoReq.java index 61122aea7..990a767d2 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/DateInfoReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/DateInfoReq.java @@ -22,8 +22,8 @@ public class DateInfoReq { private String datePeriod; private List unavailableDateList = new ArrayList<>(); - public DateInfoReq( - String type, Long itemId, String dateFormat, String startDate, String endDate) { + public DateInfoReq(String type, Long itemId, String dateFormat, String startDate, + String endDate) { this.type = type; this.itemId = itemId; this.dateFormat = dateFormat; @@ -31,13 +31,8 @@ public class DateInfoReq { this.endDate = endDate; } - public DateInfoReq( - String type, - Long itemId, - String dateFormat, - String startDate, - String endDate, - List unavailableDateList) { + public DateInfoReq(String type, Long itemId, String dateFormat, String startDate, + String endDate, List unavailableDateList) { this.type = type; this.itemId = itemId; this.dateFormat = dateFormat; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/DictItemReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/DictItemReq.java index 3427f9ab6..59ebdec5e 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/DictItemReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/DictItemReq.java @@ -11,10 +11,13 @@ import lombok.Data; public class DictItemReq { private Long id; - @NotNull private TypeEnums type; - @NotNull private Long itemId; + @NotNull + private TypeEnums type; + @NotNull + private Long itemId; private ItemValueConfig config; /** ONLINE - 正常更新 OFFLINE - 停止更新,但字典文件不删除 DELETED - 停止更新,且删除字典文件 */ - @NotNull private StatusEnum status; + @NotNull + private StatusEnum status; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/DictSingleTaskReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/DictSingleTaskReq.java index 4b4a7a4c1..00703a8eb 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/DictSingleTaskReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/DictSingleTaskReq.java @@ -9,6 +9,8 @@ import lombok.Data; @Data @Builder public class DictSingleTaskReq { - @NotNull private TypeEnums type; - @NotNull private Long itemId; + @NotNull + private TypeEnums type; + @NotNull + private Long itemId; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/DimensionValueReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/DimensionValueReq.java index 61d326956..879f2c6b3 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/DimensionValueReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/DimensionValueReq.java @@ -13,13 +13,15 @@ public class DimensionValueReq { private Integer agentId; - @NotNull private Long elementID; + @NotNull + private Long elementID; private Long modelId; private String bizName; - @NotNull private String value; + @NotNull + private String value; private Set dataSetIds; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ItemValueReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ItemValueReq.java index 297e3bf7d..d3b6179ac 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ItemValueReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/ItemValueReq.java @@ -10,7 +10,8 @@ import lombok.ToString; @ToString public class ItemValueReq { - @NotNull private Long id; + @NotNull + private Long id; private DateConf dateConf; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryFilter.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryFilter.java index 2ef1a3e5b..bdfbc0427 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryFilter.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryFilter.java @@ -30,10 +30,8 @@ public class QueryFilter { return false; } QueryFilter that = (QueryFilter) o; - return Objects.equal(bizName, that.bizName) - && Objects.equal(name, that.name) - && operator == that.operator - && Objects.equal(value, that.value) + return Objects.equal(bizName, that.bizName) && Objects.equal(name, that.name) + && operator == that.operator && Objects.equal(value, that.value) && Objects.equal(elementID, that.elementID) && Objects.equal(function, that.function); } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryRuleReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryRuleReq.java index a5b3be1b2..ba935f634 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryRuleReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryRuleReq.java @@ -23,10 +23,12 @@ public class QueryRuleReq extends SchemaItem { private Integer priority = 1; /** 规则类型 */ - @NotNull private QueryRuleType ruleType; + @NotNull + private QueryRuleType ruleType; /** 具体规则信息 */ - @NotNull private RuleInfo rule; + @NotNull + private RuleInfo rule; /** 规则输出信息 */ private ActionInfo action; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryStructReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryStructReq.java index 539bae282..35cddca71 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryStructReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/QueryStructReq.java @@ -54,10 +54,8 @@ public class QueryStructReq extends SemanticQueryReq { public List getGroups() { if (!CollectionUtils.isEmpty(this.groups)) { - this.groups = - groups.stream() - .filter(group -> !StringUtils.isEmpty(group)) - .collect(Collectors.toList()); + this.groups = groups.stream().filter(group -> !StringUtils.isEmpty(group)) + .collect(Collectors.toList()); } if (CollectionUtils.isEmpty(this.groups)) { @@ -195,8 +193,8 @@ public class QueryStructReq extends SemanticQueryReq { return selectItems; } - private SelectItem buildAggregatorSelectItem( - Aggregator aggregator, QueryStructReq queryStructReq) { + private SelectItem buildAggregatorSelectItem(Aggregator aggregator, + QueryStructReq queryStructReq) { String columnName = aggregator.getColumn(); if (queryStructReq.getQueryType().isNativeAggQuery()) { return new SelectItem(new Column(columnName)); @@ -213,10 +211,8 @@ public class QueryStructReq extends SemanticQueryReq { } function.setParameters(new ExpressionList(new Column(columnName))); SelectItem selectExpressionItem = new SelectItem(function); - String alias = - StringUtils.isNotBlank(aggregator.getAlias()) - ? aggregator.getAlias() - : columnName; + String alias = StringUtils.isNotBlank(aggregator.getAlias()) ? aggregator.getAlias() + : columnName; selectExpressionItem.setAlias(new Alias(alias)); return selectExpressionItem; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/TagBatchCreateReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/TagBatchCreateReq.java index 3bd3186a2..e18f2d0e8 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/TagBatchCreateReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/TagBatchCreateReq.java @@ -11,7 +11,8 @@ import java.util.List; @ToString @Data public class TagBatchCreateReq { - @NotNull private Long modelId; + @NotNull + private Long modelId; private SchemaElementType type; private List itemIds; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/TagObjectReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/TagObjectReq.java index a900bed59..ebabd7c9e 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/TagObjectReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/TagObjectReq.java @@ -13,7 +13,8 @@ import java.util.Objects; @Data public class TagObjectReq extends SchemaItem { - @NotNull private Long domainId; + @NotNull + private Long domainId; private Map ext = new HashMap<>(); public String getExtJson() { diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/TagReq.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/TagReq.java index 135a1a732..bad5c03c1 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/TagReq.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/request/TagReq.java @@ -11,7 +11,9 @@ public class TagReq extends RecordInfo { private Long id; - @NotNull private TagDefineType tagDefineType; + @NotNull + private TagDefineType tagDefineType; - @NotNull private Long itemId; + @NotNull + private Long itemId; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/DataSetResp.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/DataSetResp.java index f8f9d0620..82532f46d 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/DataSetResp.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/DataSetResp.java @@ -35,30 +35,23 @@ public class DataSetResp extends SchemaItem { private List allDimensions = new ArrayList<>(); public List metricIds() { - return getDataSetModelConfigs().stream() - .map(DataSetModelConfig::getMetrics) - .flatMap(Collection::stream) - .collect(Collectors.toList()); + return getDataSetModelConfigs().stream().map(DataSetModelConfig::getMetrics) + .flatMap(Collection::stream).collect(Collectors.toList()); } public List dimensionIds() { - return getDataSetModelConfigs().stream() - .map(DataSetModelConfig::getDimensions) - .flatMap(Collection::stream) - .collect(Collectors.toList()); + return getDataSetModelConfigs().stream().map(DataSetModelConfig::getDimensions) + .flatMap(Collection::stream).collect(Collectors.toList()); } public List getAllModels() { - return getDataSetModelConfigs().stream() - .map(DataSetModelConfig::getId) + return getDataSetModelConfigs().stream().map(DataSetModelConfig::getId) .collect(Collectors.toList()); } public List getAllIncludeAllModels() { - return getDataSetModelConfigs().stream() - .filter(DataSetModelConfig::getIncludesAll) - .map(DataSetModelConfig::getId) - .collect(Collectors.toList()); + return getDataSetModelConfigs().stream().filter(DataSetModelConfig::getIncludesAll) + .map(DataSetModelConfig::getId).collect(Collectors.toList()); } private List getDataSetModelConfigs() { diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/DictItemResp.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/DictItemResp.java index f2ea41736..d7f52d657 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/DictItemResp.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/DictItemResp.java @@ -18,12 +18,15 @@ public class DictItemResp { private String bizName; - @NotNull private TypeEnums type; - @NotNull private Long itemId; + @NotNull + private TypeEnums type; + @NotNull + private Long itemId; private ItemValueConfig config; /** ONLINE - 正常更新 OFFLINE - 停止更新,但字典文件不删除 DELETED - 停止更新,且删除字典文件 */ - @NotNull private StatusEnum status; + @NotNull + private StatusEnum status; public String getNature() { return UNDERLINE + modelId + UNDERLINE + itemId; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/MetricResp.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/MetricResp.java index 0ad44c364..5f1397ee8 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/MetricResp.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/MetricResp.java @@ -83,8 +83,7 @@ public class MetricResp extends SchemaItem { return ""; } return relateDimension.getDrillDownDimensions().stream() - .map(DrillDownDimension::getDimensionId) - .map(String::valueOf) + .map(DrillDownDimension::getDimensionId).map(String::valueOf) .collect(Collectors.joining(",")); } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/ModelResp.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/ModelResp.java index 3cc475def..a540ea6a9 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/ModelResp.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/ModelResp.java @@ -89,10 +89,8 @@ public class ModelResp extends SchemaItem { return fieldSet; } if (!CollectionUtils.isEmpty(modelDetail.getFields())) { - fieldSet.addAll( - modelDetail.getFields().stream() - .map(Field::getFieldName) - .collect(Collectors.toSet())); + fieldSet.addAll(modelDetail.getFields().stream().map(Field::getFieldName) + .collect(Collectors.toSet())); } return fieldSet; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/ModelSchemaResp.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/ModelSchemaResp.java index c24d9ca17..17ff6a8f3 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/ModelSchemaResp.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/ModelSchemaResp.java @@ -25,11 +25,10 @@ public class ModelSchemaResp extends ModelResp { return Sets.newHashSet(); } else { Set modelClusterSet = new HashSet(); - this.modelRelas.forEach( - (modelRela) -> { - modelClusterSet.add(modelRela.getToModelId()); - modelClusterSet.add(modelRela.getFromModelId()); - }); + this.modelRelas.forEach((modelRela) -> { + modelClusterSet.add(modelRela.getToModelId()); + modelClusterSet.add(modelRela.getFromModelId()); + }); return modelClusterSet; } } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/ParseResp.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/ParseResp.java index 367ab63ff..1870deedf 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/ParseResp.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/ParseResp.java @@ -18,9 +18,7 @@ public class ParseResp { private ParseTimeCostResp parseTimeCost = new ParseTimeCostResp(); public enum ParseState { - COMPLETED, - PENDING, - FAILED + COMPLETED, PENDING, FAILED } public ParseResp(String queryText) { @@ -29,10 +27,9 @@ public class ParseResp { } public List getSelectedParses() { - selectedParses = - selectedParses.stream() - .sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed()) - .collect(Collectors.toList()); + selectedParses = selectedParses.stream() + .sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed()) + .collect(Collectors.toList()); generateParseInfoId(selectedParses); return selectedParses; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/QueryRuleResp.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/QueryRuleResp.java index 2814f3919..68b88af94 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/QueryRuleResp.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/QueryRuleResp.java @@ -20,10 +20,12 @@ public class QueryRuleResp extends SchemaItem { private Integer priority = 1; /** 规则类型 */ - @NotNull private QueryRuleType ruleType; + @NotNull + private QueryRuleType ruleType; /** 具体规则信息 */ - @NotNull private RuleInfo rule; + @NotNull + private RuleInfo rule; /** 规则输出信息 */ private ActionInfo action; diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/QueryState.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/QueryState.java index 9b4c39703..ee266b1ff 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/QueryState.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/QueryState.java @@ -1,8 +1,5 @@ package com.tencent.supersonic.headless.api.pojo.response; public enum QueryState { - SUCCESS, - SEARCH_EXCEPTION, - EMPTY, - INVALID; + SUCCESS, SEARCH_EXCEPTION, EMPTY, INVALID; } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/SemanticQueryResp.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/SemanticQueryResp.java index 183ba8d7a..112bdd7af 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/SemanticQueryResp.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/SemanticQueryResp.java @@ -28,10 +28,8 @@ public class SemanticQueryResp extends QueryResult> { } public List getDimensionColumns() { - return columns.stream() - .filter( - queryColumn -> - !SemanticType.NUMBER.name().equals(queryColumn.getShowType())) + return columns.stream().filter( + queryColumn -> !SemanticType.NUMBER.name().equals(queryColumn.getShowType())) .collect(Collectors.toList()); } } diff --git a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/SemanticSchemaResp.java b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/SemanticSchemaResp.java index df30a0efa..fd7c39836 100644 --- a/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/SemanticSchemaResp.java +++ b/headless/api/src/main/java/com/tencent/supersonic/headless/api/pojo/response/SemanticSchemaResp.java @@ -40,30 +40,23 @@ public class SemanticSchemaResp { } public MetricSchemaResp getMetric(String bizName) { - return metrics.stream() - .filter(metric -> bizName.equalsIgnoreCase(metric.getBizName())) - .findFirst() - .orElse(null); + return metrics.stream().filter(metric -> bizName.equalsIgnoreCase(metric.getBizName())) + .findFirst().orElse(null); } public MetricSchemaResp getMetric(Long id) { - return metrics.stream() - .filter(metric -> id.equals(metric.getId())) - .findFirst() + return metrics.stream().filter(metric -> id.equals(metric.getId())).findFirst() .orElse(null); } public DimSchemaResp getDimension(String bizName) { return dimensions.stream() - .filter(dimension -> bizName.equalsIgnoreCase(dimension.getBizName())) - .findFirst() + .filter(dimension -> bizName.equalsIgnoreCase(dimension.getBizName())).findFirst() .orElse(null); } public DimSchemaResp getDimension(Long id) { - return dimensions.stream() - .filter(dimension -> id.equals(dimension.getId())) - .findFirst() + return dimensions.stream().filter(dimension -> id.equals(dimension.getId())).findFirst() .orElse(null); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java index 1ddb40255..ba1d19607 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/ChatQueryContext.java @@ -41,14 +41,17 @@ public class ChatQueryContext { private Map> modelIdToDataSetIds; private User user; private boolean saveAnswer; - @Builder.Default private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM; + @Builder.Default + private Text2SQLType text2SQLType = Text2SQLType.RULE_AND_LLM; private QueryFilters queryFilters; private List candidateQueries = new ArrayList<>(); private SchemaMapInfo mapInfo = new SchemaMapInfo(); private SemanticParseInfo contextParseInfo; private MapModeEnum mapModeEnum = MapModeEnum.STRICT; - @JsonIgnore private SemanticSchema semanticSchema; - @JsonIgnore private ChatWorkflowState chatWorkflowState; + @JsonIgnore + private SemanticSchema semanticSchema; + @JsonIgnore + private ChatWorkflowState chatWorkflowState; private QueryDataType queryDataType = QueryDataType.ALL; private ChatModelConfig modelConfig; private PromptConfig promptConfig; @@ -58,14 +61,11 @@ public class ChatQueryContext { ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); int parseShowCount = Integer.parseInt(parserConfig.getParameterValue(ParserConfig.PARSER_SHOW_COUNT)); - candidateQueries = - candidateQueries.stream() - .sorted( - Comparator.comparing( - semanticQuery -> semanticQuery.getParseInfo().getScore(), - Comparator.reverseOrder())) - .limit(parseShowCount) - .collect(Collectors.toList()); + candidateQueries = candidateQueries.stream() + .sorted(Comparator.comparing( + semanticQuery -> semanticQuery.getParseInfo().getScore(), + Comparator.reverseOrder())) + .limit(parseShowCount).collect(Collectors.toList()); return candidateQueries; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/AggCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/AggCorrector.java index e605dcd85..33b9e98a5 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/AggCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/AggCorrector.java @@ -17,11 +17,10 @@ public class AggCorrector extends BaseSemanticCorrector { addAggregate(chatQueryContext, semanticParseInfo); } - private void addAggregate( - ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { - List sqlGroupByFields = - SqlSelectHelper.getGroupByFields( - semanticParseInfo.getSqlInfo().getCorrectedS2SQL()); + private void addAggregate(ChatQueryContext chatQueryContext, + SemanticParseInfo semanticParseInfo) { + List sqlGroupByFields = SqlSelectHelper + .getGroupByFields(semanticParseInfo.getSqlInfo().getCorrectedS2SQL()); if (CollectionUtils.isEmpty(sqlGroupByFields)) { return; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java index 813d781a6..9f324f8c3 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/BaseSemanticCorrector.java @@ -35,20 +35,18 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { return; } doCorrect(chatQueryContext, semanticParseInfo); - log.debug( - "sqlCorrection:{} sql:{}", - this.getClass().getSimpleName(), + log.debug("sqlCorrection:{} sql:{}", this.getClass().getSimpleName(), semanticParseInfo.getSqlInfo()); } catch (Exception e) { log.error(String.format("correct error,sqlInfo:%s", semanticParseInfo.getSqlInfo()), e); } } - public abstract void doCorrect( - ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo); + public abstract void doCorrect(ChatQueryContext chatQueryContext, + SemanticParseInfo semanticParseInfo); - protected Map getFieldNameMap( - ChatQueryContext chatQueryContext, Long dataSetId) { + protected Map getFieldNameMap(ChatQueryContext chatQueryContext, + Long dataSetId) { Map result = getFieldNameMapFromDB(chatQueryContext, dataSetId); if (chatQueryContext.containsPartitionDimensions(dataSetId)) { @@ -63,8 +61,8 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { return result; } - private static Map getFieldNameMapFromDB( - ChatQueryContext chatQueryContext, Long dataSetId) { + private static Map getFieldNameMapFromDB(ChatQueryContext chatQueryContext, + Long dataSetId) { SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); List dbAllFields = new ArrayList<>(); @@ -72,51 +70,38 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { dbAllFields.addAll(semanticSchema.getDimensions()); // support fieldName and field alias - return dbAllFields.stream() - .filter(entry -> dataSetId.equals(entry.getDataSetId())) - .flatMap( - schemaElement -> { - Set elements = new HashSet<>(); - elements.add(schemaElement.getName()); - if (!CollectionUtils.isEmpty(schemaElement.getAlias())) { - elements.addAll(schemaElement.getAlias()); - } - return elements.stream(); - }) - .collect(Collectors.toMap(a -> a, a -> a, (k1, k2) -> k1)); + return dbAllFields.stream().filter(entry -> dataSetId.equals(entry.getDataSetId())) + .flatMap(schemaElement -> { + Set elements = new HashSet<>(); + elements.add(schemaElement.getName()); + if (!CollectionUtils.isEmpty(schemaElement.getAlias())) { + elements.addAll(schemaElement.getAlias()); + } + return elements.stream(); + }).collect(Collectors.toMap(a -> a, a -> a, (k1, k2) -> k1)); } - protected void addAggregateToMetric( - ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + protected void addAggregateToMetric(ChatQueryContext chatQueryContext, + SemanticParseInfo semanticParseInfo) { // add aggregate to all metric String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); Long dataSetId = semanticParseInfo.getDataSet().getDataSetId(); List metrics = getMetricElements(chatQueryContext, dataSetId); - Map metricToAggregate = - metrics.stream() - .map( - schemaElement -> { - if (Objects.isNull(schemaElement.getDefaultAgg())) { - schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name()); - } - return schemaElement; - }) - .flatMap( - schemaElement -> { - Set elements = new HashSet<>(); - elements.add(schemaElement.getName()); - if (!CollectionUtils.isEmpty(schemaElement.getAlias())) { - elements.addAll(schemaElement.getAlias()); - } - return elements.stream() - .map( - element -> - Pair.of( - element, - schemaElement.getDefaultAgg())); - }) - .collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1)); + Map metricToAggregate = metrics.stream().map(schemaElement -> { + if (Objects.isNull(schemaElement.getDefaultAgg())) { + schemaElement.setDefaultAgg(AggregateTypeEnum.SUM.name()); + } + return schemaElement; + }).flatMap(schemaElement -> { + Set elements = new HashSet<>(); + elements.add(schemaElement.getName()); + if (!CollectionUtils.isEmpty(schemaElement.getAlias())) { + elements.addAll(schemaElement.getAlias()); + } + return elements.stream() + .map(element -> Pair.of(element, schemaElement.getDefaultAgg())); + }).collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1)); if (CollectionUtils.isEmpty(metricToAggregate)) { return; @@ -125,39 +110,36 @@ public abstract class BaseSemanticCorrector implements SemanticCorrector { semanticParseInfo.getSqlInfo().setCorrectedS2SQL(aggregateSql); } - protected List getMetricElements( - ChatQueryContext chatQueryContext, Long dataSetId) { + protected List getMetricElements(ChatQueryContext chatQueryContext, + Long dataSetId) { SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); return semanticSchema.getMetrics(dataSetId); } protected Set getDimensions(Long dataSetId, SemanticSchema semanticSchema) { Set dimensions = - semanticSchema.getDimensions(dataSetId).stream() - .flatMap( - schemaElement -> { - Set elements = new HashSet<>(); - elements.add(schemaElement.getName()); - if (!CollectionUtils.isEmpty(schemaElement.getAlias())) { - elements.addAll(schemaElement.getAlias()); - } - return elements.stream(); - }) - .collect(Collectors.toSet()); + semanticSchema.getDimensions(dataSetId).stream().flatMap(schemaElement -> { + Set elements = new HashSet<>(); + elements.add(schemaElement.getName()); + if (!CollectionUtils.isEmpty(schemaElement.getAlias())) { + elements.addAll(schemaElement.getAlias()); + } + return elements.stream(); + }).collect(Collectors.toSet()); dimensions.add(TimeDimensionEnum.DAY.getChName()); return dimensions; } - protected boolean containsPartitionDimensions( - ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + protected boolean containsPartitionDimensions(ChatQueryContext chatQueryContext, + SemanticParseInfo semanticParseInfo) { Long dataSetId = semanticParseInfo.getDataSetId(); SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId); return dataSetSchema.containsPartitionDimensions(); } - protected void removeDateIfExist( - ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + protected void removeDateIfExist(ChatQueryContext chatQueryContext, + SemanticParseInfo semanticParseInfo) { String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); Set removeFieldNames = new HashSet<>(); removeFieldNames.addAll(TimeDimensionEnum.getChNameList()); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java index 1c930bf9c..1dd74fc27 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/GroupByCorrector.java @@ -31,8 +31,8 @@ public class GroupByCorrector extends BaseSemanticCorrector { addGroupByFields(chatQueryContext, semanticParseInfo); } - private Boolean needAddGroupBy( - ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + private Boolean needAddGroupBy(ChatQueryContext chatQueryContext, + SemanticParseInfo semanticParseInfo) { if (!QueryType.AGGREGATE.equals(semanticParseInfo.getQueryType())) { return false; } @@ -66,8 +66,8 @@ public class GroupByCorrector extends BaseSemanticCorrector { return true; } - private void addGroupByFields( - ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + private void addGroupByFields(ChatQueryContext chatQueryContext, + SemanticParseInfo semanticParseInfo) { Long dataSetId = semanticParseInfo.getDataSetId(); // add dimension group by SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); @@ -78,19 +78,14 @@ public class GroupByCorrector extends BaseSemanticCorrector { List selectFields = SqlSelectHelper.gePureSelectFields(correctS2SQL); List aggregateFields = SqlSelectHelper.getAggregateFields(correctS2SQL); Set groupByFields = - selectFields.stream() - .filter(field -> dimensions.contains(field)) - .filter( - field -> { - if (!CollectionUtils.isEmpty(aggregateFields) - && aggregateFields.contains(field)) { - return false; - } - return true; - }) - .collect(Collectors.toSet()); - semanticParseInfo - .getSqlInfo() + selectFields.stream().filter(field -> dimensions.contains(field)).filter(field -> { + if (!CollectionUtils.isEmpty(aggregateFields) + && aggregateFields.contains(field)) { + return false; + } + return true; + }).collect(Collectors.toSet()); + semanticParseInfo.getSqlInfo() .setCorrectedS2SQL(SqlAddHelper.addGroupBy(correctS2SQL, groupByFields)); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/HavingCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/HavingCorrector.java index 381e8231b..cac0ef368 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/HavingCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/HavingCorrector.java @@ -42,10 +42,8 @@ public class HavingCorrector extends BaseSemanticCorrector { SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); - Set metrics = - semanticSchema.getMetrics(dataSet).stream() - .map(schemaElement -> schemaElement.getName()) - .collect(Collectors.toSet()); + Set metrics = semanticSchema.getMetrics(dataSet).stream() + .map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet()); if (CollectionUtils.isEmpty(metrics)) { return; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/S2SqlDateHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/S2SqlDateHelper.java index aa0b60318..d1a6ab340 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/S2SqlDateHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/S2SqlDateHelper.java @@ -13,13 +13,13 @@ import java.util.Date; public class S2SqlDateHelper { - public static Pair calculateDateRange( - TimeDefaultConfig timeConfig, String timeFormat) { + public static Pair calculateDateRange(TimeDefaultConfig timeConfig, + String timeFormat) { return calculateDateRange(DateUtils.getBeforeDate(0), timeConfig, timeFormat); } - public static Pair calculateDateRange( - String currentDate, TimeDefaultConfig timeConfig, String timeFormat) { + public static Pair calculateDateRange(String currentDate, + TimeDefaultConfig timeConfig, String timeFormat) { Integer unit = timeConfig.getUnit(); if (timeConfig == null || unit == null || unit < 0) { return Pair.of(null, null); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java index 6162fb0de..918713e5d 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrector.java @@ -46,8 +46,8 @@ public class SchemaCorrector extends BaseSemanticCorrector { correctFieldName(chatQueryContext, semanticParseInfo); } - private void removeDateFields( - ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + private void removeDateFields(ChatQueryContext chatQueryContext, + SemanticParseInfo semanticParseInfo) { if (containsPartitionDimensions(chatQueryContext, semanticParseInfo)) { return; } @@ -61,8 +61,8 @@ public class SchemaCorrector extends BaseSemanticCorrector { sqlInfo.setCorrectedS2SQL(sql); } - private void correctFieldName( - ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + private void correctFieldName(ChatQueryContext chatQueryContext, + SemanticParseInfo semanticParseInfo) { Map fieldNameMap = getFieldNameMap(chatQueryContext, semanticParseInfo.getDataSetId()); // add as fieldName @@ -82,19 +82,13 @@ public class SchemaCorrector extends BaseSemanticCorrector { } Map> fieldValueToFieldNames = - linking.stream() - .collect( - Collectors.groupingBy( - LLMReq.ElementValue::getFieldValue, - Collectors.mapping( - LLMReq.ElementValue::getFieldName, - Collectors.toSet()))); + linking.stream().collect(Collectors.groupingBy(LLMReq.ElementValue::getFieldValue, + Collectors.mapping(LLMReq.ElementValue::getFieldName, Collectors.toSet()))); SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); - String sql = - SqlReplaceHelper.replaceFieldNameByValue( - sqlInfo.getCorrectedS2SQL(), fieldValueToFieldNames); + String sql = SqlReplaceHelper.replaceFieldNameByValue(sqlInfo.getCorrectedS2SQL(), + fieldValueToFieldNames); sqlInfo.setCorrectedS2SQL(sql); } @@ -117,27 +111,20 @@ public class SchemaCorrector extends BaseSemanticCorrector { return; } - Map> filedNameToValueMap = - linking.stream() - .collect( - Collectors.groupingBy( - LLMReq.ElementValue::getFieldName, - Collectors.mapping( - LLMReq.ElementValue::getFieldValue, - Collectors.toMap( - oldValue -> oldValue, - newValue -> newValue, - (existingValue, newValue) -> newValue)))); + Map> filedNameToValueMap = linking.stream() + .collect(Collectors.groupingBy(LLMReq.ElementValue::getFieldName, + Collectors.mapping(LLMReq.ElementValue::getFieldValue, + Collectors.toMap(oldValue -> oldValue, newValue -> newValue, + (existingValue, newValue) -> newValue)))); SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); - String sql = - SqlReplaceHelper.replaceValue( - sqlInfo.getCorrectedS2SQL(), filedNameToValueMap, false); + String sql = SqlReplaceHelper.replaceValue(sqlInfo.getCorrectedS2SQL(), filedNameToValueMap, + false); sqlInfo.setCorrectedS2SQL(sql); } - public void removeFilterIfNotInLinkingValue( - ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + public void removeFilterIfNotInLinkingValue(ChatQueryContext chatQueryContext, + SemanticParseInfo semanticParseInfo) { SqlInfo sqlInfo = semanticParseInfo.getSqlInfo(); String correctS2SQL = sqlInfo.getCorrectedS2SQL(); List whereExpressionList = @@ -152,37 +139,21 @@ public class SchemaCorrector extends BaseSemanticCorrector { if (CollectionUtils.isEmpty(linkingValues)) { linkingValues = new ArrayList<>(); } - Set linkingFieldNames = - linkingValues.stream() - .map(linking -> linking.getFieldName()) - .collect(Collectors.toSet()); + Set linkingFieldNames = linkingValues.stream() + .map(linking -> linking.getFieldName()).collect(Collectors.toSet()); - Set removeFieldNames = - whereExpressionList.stream() - .filter( - fieldExpression -> - StringUtils.isBlank(fieldExpression.getFunction())) - .filter( - fieldExpression -> - !TimeDimensionEnum.containsTimeDimension( - fieldExpression.getFieldName())) - .filter( - fieldExpression -> - FilterOperatorEnum.EQUALS - .getValue() - .equals(fieldExpression.getOperator())) - .filter( - fieldExpression -> - dimensions.contains(fieldExpression.getFieldName())) - .filter( - fieldExpression -> - !DateUtils.isAnyDateString( - fieldExpression.getFieldValue().toString())) - .filter( - fieldExpression -> - !linkingFieldNames.contains(fieldExpression.getFieldName())) - .map(fieldExpression -> fieldExpression.getFieldName()) - .collect(Collectors.toSet()); + Set removeFieldNames = whereExpressionList.stream() + .filter(fieldExpression -> StringUtils.isBlank(fieldExpression.getFunction())) + .filter(fieldExpression -> !TimeDimensionEnum + .containsTimeDimension(fieldExpression.getFieldName())) + .filter(fieldExpression -> FilterOperatorEnum.EQUALS.getValue() + .equals(fieldExpression.getOperator())) + .filter(fieldExpression -> dimensions.contains(fieldExpression.getFieldName())) + .filter(fieldExpression -> !DateUtils + .isAnyDateString(fieldExpression.getFieldValue().toString())) + .filter(fieldExpression -> !linkingFieldNames + .contains(fieldExpression.getFieldName())) + .map(fieldExpression -> fieldExpression.getFieldName()).collect(Collectors.toSet()); String sql = SqlRemoveHelper.removeWhereCondition(correctS2SQL, removeFieldNames); sqlInfo.setCorrectedS2SQL(sql); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java index 1582cffd9..e70d5d13c 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrector.java @@ -34,8 +34,7 @@ public class SelectCorrector extends BaseSemanticCorrector { List selectFields = SqlSelectHelper.getSelectFields(correctS2SQL); // If the number of aggregated fields is equal to the number of queried fields, do not add // fields to select. - if (!CollectionUtils.isEmpty(aggregateFields) - && !CollectionUtils.isEmpty(selectFields) + if (!CollectionUtils.isEmpty(aggregateFields) && !CollectionUtils.isEmpty(selectFields) && aggregateFields.size() == selectFields.size()) { return; } @@ -43,10 +42,8 @@ public class SelectCorrector extends BaseSemanticCorrector { semanticParseInfo.getSqlInfo().setCorrectedS2SQL(correctS2SQL); } - protected String addFieldsToSelect( - ChatQueryContext chatQueryContext, - SemanticParseInfo semanticParseInfo, - String correctS2SQL) { + protected String addFieldsToSelect(ChatQueryContext chatQueryContext, + SemanticParseInfo semanticParseInfo, String correctS2SQL) { correctS2SQL = addTagDefaultFields(chatQueryContext, semanticParseInfo, correctS2SQL); Set selectFields = new HashSet<>(SqlSelectHelper.getSelectFields(correctS2SQL)); @@ -69,10 +66,8 @@ public class SelectCorrector extends BaseSemanticCorrector { return addFieldsToSelectSql; } - private String addTagDefaultFields( - ChatQueryContext chatQueryContext, - SemanticParseInfo semanticParseInfo, - String correctS2SQL) { + private String addTagDefaultFields(ChatQueryContext chatQueryContext, + SemanticParseInfo semanticParseInfo, String correctS2SQL) { // If it is in DETAIL mode and select *, add default metrics and dimensions. boolean hasAsterisk = SqlSelectFunctionHelper.hasAsterisk(correctS2SQL); if (!(hasAsterisk && QueryType.DETAIL.equals(semanticParseInfo.getQueryType()))) { @@ -84,17 +79,13 @@ public class SelectCorrector extends BaseSemanticCorrector { Set needAddDefaultFields = new HashSet<>(); if (Objects.nonNull(dataSetSchema)) { if (!CollectionUtils.isEmpty(dataSetSchema.getTagDefaultMetrics())) { - Set metrics = - dataSetSchema.getTagDefaultMetrics().stream() - .map(schemaElement -> schemaElement.getName()) - .collect(Collectors.toSet()); + Set metrics = dataSetSchema.getTagDefaultMetrics().stream() + .map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet()); needAddDefaultFields.addAll(metrics); } if (!CollectionUtils.isEmpty(dataSetSchema.getTagDefaultDimensions())) { - Set dimensions = - dataSetSchema.getTagDefaultDimensions().stream() - .map(schemaElement -> schemaElement.getName()) - .collect(Collectors.toSet()); + Set dimensions = dataSetSchema.getTagDefaultDimensions().stream() + .map(schemaElement -> schemaElement.getName()).collect(Collectors.toSet()); needAddDefaultFields.addAll(dimensions); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java index 4deeb6b49..68ff19673 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrector.java @@ -36,15 +36,14 @@ public class TimeCorrector extends BaseSemanticCorrector { } } - private void addDateIfNotExist( - ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + private void addDateIfNotExist(ChatQueryContext chatQueryContext, + SemanticParseInfo semanticParseInfo) { String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); List whereFields = SqlSelectHelper.getWhereFields(correctS2SQL); Long dataSetId = semanticParseInfo.getDataSetId(); DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap().get(dataSetId); - if (Objects.isNull(dataSetSchema) - || Objects.isNull(dataSetSchema.getPartitionDimension()) + if (Objects.isNull(dataSetSchema) || Objects.isNull(dataSetSchema.getPartitionDimension()) || Objects.isNull(dataSetSchema.getPartitionDimension().getName()) || TimeDimensionEnum.containsZhTimeDimension(whereFields)) { return; @@ -66,13 +65,8 @@ public class TimeCorrector extends BaseSemanticCorrector { correctS2SQL = SqlAddHelper.addParenthesisToWhere(correctS2SQL); String startDateLeft = dateRange.getLeft(); String endDateRight = dateRange.getRight(); - String condExpr = - String.format( - " ( %s >= '%s' and %s <= '%s' )", - partitionDimension, - startDateLeft, - partitionDimension, - endDateRight); + String condExpr = String.format(" ( %s >= '%s' and %s <= '%s' )", + partitionDimension, startDateLeft, partitionDimension, endDateRight); correctS2SQL = addConditionToSQL(correctS2SQL, condExpr); } } @@ -83,8 +77,7 @@ public class TimeCorrector extends BaseSemanticCorrector { String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); DateBoundInfo dateBoundInfo = SqlDateSelectHelper.getDateBoundInfo(correctS2SQL); - if (dateBoundInfo != null - && StringUtils.isBlank(dateBoundInfo.getLowerBound()) + if (dateBoundInfo != null && StringUtils.isBlank(dateBoundInfo.getLowerBound()) && StringUtils.isNotBlank(dateBoundInfo.getUpperBound()) && StringUtils.isNotBlank(dateBoundInfo.getUpperDate())) { String upperDate = dateBoundInfo.getUpperDate(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java index 02589d82a..6bae88e39 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrector.java @@ -31,8 +31,8 @@ public class WhereCorrector extends BaseSemanticCorrector { updateFieldValueByTechName(chatQueryContext, semanticParseInfo); } - protected void addQueryFilter( - ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + protected void addQueryFilter(ChatQueryContext chatQueryContext, + SemanticParseInfo semanticParseInfo) { String queryFilter = getQueryFilter(chatQueryContext.getQueryFilters()); String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); @@ -55,8 +55,8 @@ public class WhereCorrector extends BaseSemanticCorrector { return QueryFilterParser.parse(queryFilters); } - private void updateFieldValueByTechName( - ChatQueryContext chatQueryContext, SemanticParseInfo semanticParseInfo) { + private void updateFieldValueByTechName(ChatQueryContext chatQueryContext, + SemanticParseInfo semanticParseInfo) { SemanticSchema semanticSchema = chatQueryContext.getSemanticSchema(); Long dataSetId = semanticParseInfo.getDataSetId(); List dimensions = semanticSchema.getDimensions(dataSetId); @@ -75,50 +75,25 @@ public class WhereCorrector extends BaseSemanticCorrector { private Map> getAliasAndBizNameToTechName( List dimensions) { return dimensions.stream() - .filter( - dimension -> - Objects.nonNull(dimension) - && StringUtils.isNotEmpty(dimension.getName()) - && !CollectionUtils.isEmpty(dimension.getSchemaValueMaps())) - .collect( - Collectors.toMap( - SchemaElement::getName, - dimension -> - dimension.getSchemaValueMaps().stream() - .filter( - valueMap -> - Objects.nonNull(valueMap) - && StringUtils.isNotEmpty( - valueMap - .getTechName())) - .flatMap( - valueMap -> { - Map map = - new HashMap<>(); - if (StringUtils.isNotEmpty( - valueMap.getBizName())) { - map.put( - valueMap.getBizName(), - valueMap.getTechName()); - } - if (!CollectionUtils.isEmpty( - valueMap.getAlias())) { - valueMap.getAlias().stream() - .filter( - StringUtils - ::isNotEmpty) - .forEach( - alias -> - map.put( - alias, - valueMap - .getTechName())); - } - return map.entrySet().stream(); - }) - .collect( - Collectors.toMap( - Map.Entry::getKey, - Map.Entry::getValue)))); + .filter(dimension -> Objects.nonNull(dimension) + && StringUtils.isNotEmpty(dimension.getName()) + && !CollectionUtils.isEmpty(dimension.getSchemaValueMaps())) + .collect(Collectors.toMap(SchemaElement::getName, + dimension -> dimension.getSchemaValueMaps().stream() + .filter(valueMap -> Objects.nonNull(valueMap) + && StringUtils.isNotEmpty(valueMap.getTechName())) + .flatMap(valueMap -> { + Map map = new HashMap<>(); + if (StringUtils.isNotEmpty(valueMap.getBizName())) { + map.put(valueMap.getBizName(), valueMap.getTechName()); + } + if (!CollectionUtils.isEmpty(valueMap.getAlias())) { + valueMap.getAlias().stream().filter(StringUtils::isNotEmpty) + .forEach(alias -> map.put(alias, + valueMap.getTechName())); + } + return map.entrySet().stream(); + }).collect( + Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)))); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/DatabaseMapResult.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/DatabaseMapResult.java index aec5ddcea..36a351c26 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/DatabaseMapResult.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/DatabaseMapResult.java @@ -31,10 +31,7 @@ public class DatabaseMapResult extends MapResult { @Override public String getMapKey() { - return this.getName() - + Constants.UNDERLINE - + this.getSchemaElement().getId() - + Constants.UNDERLINE - + this.getSchemaElement().getName(); + return this.getName() + Constants.UNDERLINE + this.getSchemaElement().getId() + + Constants.UNDERLINE + this.getSchemaElement().getName(); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/DictUpdateMode.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/DictUpdateMode.java index 85d2be561..81897672d 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/DictUpdateMode.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/DictUpdateMode.java @@ -1,11 +1,8 @@ package com.tencent.supersonic.headless.chat.knowledge; public enum DictUpdateMode { - OFFLINE_FULL("OFFLINE_FULL"), - OFFLINE_MODEL("OFFLINE_MODEL"), - REALTIME_ADD("REALTIME_ADD"), - REALTIME_DELETE("REALTIME_DELETE"), - NOT_SUPPORT("NOT_SUPPORT"); + OFFLINE_FULL("OFFLINE_FULL"), OFFLINE_MODEL("OFFLINE_MODEL"), REALTIME_ADD( + "REALTIME_ADD"), REALTIME_DELETE("REALTIME_DELETE"), NOT_SUPPORT("NOT_SUPPORT"); private String value; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/DictionaryAttributeUtil.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/DictionaryAttributeUtil.java index dcae4b7f0..202a47016 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/DictionaryAttributeUtil.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/DictionaryAttributeUtil.java @@ -16,49 +16,36 @@ import java.util.stream.IntStream; /** Dictionary Attribute Util */ public class DictionaryAttributeUtil { - public static CoreDictionary.Attribute getAttribute( - CoreDictionary.Attribute old, CoreDictionary.Attribute add) { + public static CoreDictionary.Attribute getAttribute(CoreDictionary.Attribute old, + CoreDictionary.Attribute add) { Map map = new HashMap<>(); Map originalMap = new HashMap<>(); - IntStream.range(0, old.nature.length) - .boxed() - .forEach( - i -> { - map.put(old.nature[i], old.frequency[i]); - if (Objects.nonNull(old.originals)) { - originalMap.put(old.nature[i], old.originals[i]); - } - }); - IntStream.range(0, add.nature.length) - .boxed() - .forEach( - i -> { - map.put(add.nature[i], add.frequency[i]); - if (Objects.nonNull(add.originals)) { - originalMap.put(add.nature[i], add.originals[i]); - } - }); + IntStream.range(0, old.nature.length).boxed().forEach(i -> { + map.put(old.nature[i], old.frequency[i]); + if (Objects.nonNull(old.originals)) { + originalMap.put(old.nature[i], old.originals[i]); + } + }); + IntStream.range(0, add.nature.length).boxed().forEach(i -> { + map.put(add.nature[i], add.frequency[i]); + if (Objects.nonNull(add.originals)) { + originalMap.put(add.nature[i], add.originals[i]); + } + }); List> list = new LinkedList>(map.entrySet()); - Collections.sort( - list, - new Comparator>() { - public int compare( - Map.Entry o1, Map.Entry o2) { - return o2.getValue() - o1.getValue(); - } - }); + Collections.sort(list, new Comparator>() { + public int compare(Map.Entry o1, Map.Entry o2) { + return o2.getValue() - o1.getValue(); + } + }); String[] originals = list.stream().map(l -> originalMap.get(l.getKey())).toArray(String[]::new); - CoreDictionary.Attribute attribute = - new CoreDictionary.Attribute( - list.stream() - .map(i -> i.getKey()) - .collect(Collectors.toList()) - .toArray(new Nature[0]), - list.stream().map(i -> i.getValue()).mapToInt(Integer::intValue).toArray(), - originals, - list.stream().map(i -> i.getValue()).findFirst().get()); + CoreDictionary.Attribute attribute = new CoreDictionary.Attribute( + list.stream().map(i -> i.getKey()).collect(Collectors.toList()) + .toArray(new Nature[0]), + list.stream().map(i -> i.getValue()).mapToInt(Integer::intValue).toArray(), + originals, list.stream().map(i -> i.getValue()).findFirst().get()); return attribute; } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/HanlpMapResult.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/HanlpMapResult.java index 612dc3723..8f159e565 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/HanlpMapResult.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/HanlpMapResult.java @@ -43,8 +43,7 @@ public class HanlpMapResult extends MapResult { @Override public String getMapKey() { - return this.getName() - + Constants.UNDERLINE + return this.getName() + Constants.UNDERLINE + String.join(Constants.UNDERLINE, this.getNatures()); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/KnowledgeBaseService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/KnowledgeBaseService.java index adac95426..a6dabb953 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/KnowledgeBaseService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/KnowledgeBaseService.java @@ -17,25 +17,17 @@ public class KnowledgeBaseService { public void updateSemanticKnowledge(List natures) { - List prefixes = - natures.stream() - .filter( - entry -> - !entry.getNatureWithFrequency() - .contains(DictWordType.SUFFIX.getType())) - .collect(Collectors.toList()); + List prefixes = natures.stream().filter( + entry -> !entry.getNatureWithFrequency().contains(DictWordType.SUFFIX.getType())) + .collect(Collectors.toList()); for (DictWord nature : prefixes) { HanlpHelper.addToCustomDictionary(nature); } - List suffixes = - natures.stream() - .filter( - entry -> - entry.getNatureWithFrequency() - .contains(DictWordType.SUFFIX.getType())) - .collect(Collectors.toList()); + List suffixes = natures.stream().filter( + entry -> entry.getNatureWithFrequency().contains(DictWordType.SUFFIX.getType())) + .collect(Collectors.toList()); SearchService.loadSuffix(suffixes); } @@ -64,35 +56,23 @@ public class KnowledgeBaseService { return HanlpHelper.getTerms(text, modelIdToDataSetIds); } - public List prefixSearch( - String key, - int limit, - Map> modelIdToDataSetIds, - Set detectDataSetIds) { + public List prefixSearch(String key, int limit, + Map> modelIdToDataSetIds, Set detectDataSetIds) { return prefixSearchByModel(key, limit, modelIdToDataSetIds, detectDataSetIds); } - public List prefixSearchByModel( - String key, - int limit, - Map> modelIdToDataSetIds, - Set detectDataSetIds) { + public List prefixSearchByModel(String key, int limit, + Map> modelIdToDataSetIds, Set detectDataSetIds) { return SearchService.prefixSearch(key, limit, modelIdToDataSetIds, detectDataSetIds); } - public List suffixSearch( - String key, - int limit, - Map> modelIdToDataSetIds, - Set detectDataSetIds) { + public List suffixSearch(String key, int limit, + Map> modelIdToDataSetIds, Set detectDataSetIds) { return suffixSearchByModel(key, limit, modelIdToDataSetIds, detectDataSetIds); } - public List suffixSearchByModel( - String key, - int limit, - Map> modelIdToDataSetIds, - Set detectDataSetIds) { + public List suffixSearchByModel(String key, int limit, + Map> modelIdToDataSetIds, Set detectDataSetIds) { return SearchService.suffixSearch(key, limit, modelIdToDataSetIds, detectDataSetIds); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/MetaEmbeddingService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/MetaEmbeddingService.java index fe98b7637..e4f48597b 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/MetaEmbeddingService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/MetaEmbeddingService.java @@ -26,23 +26,20 @@ import java.util.stream.Stream; @Slf4j public class MetaEmbeddingService { - @Autowired private EmbeddingService embeddingService; - @Autowired private EmbeddingConfig embeddingConfig; + @Autowired + private EmbeddingService embeddingService; + @Autowired + private EmbeddingConfig embeddingConfig; - public List retrieveQuery( - RetrieveQuery retrieveQuery, - int num, - Map> modelIdToDataSetIds, - Set detectDataSetIds) { + public List retrieveQuery(RetrieveQuery retrieveQuery, int num, + Map> modelIdToDataSetIds, Set detectDataSetIds) { // dataSetIds->modelIds Set allModels = NatureHelper.getModelIds(modelIdToDataSetIds, detectDataSetIds); if (CollectionUtils.isNotEmpty(allModels)) { Map filterCondition = new HashMap<>(); - filterCondition.put( - "modelId", - allModels.stream() - .map(modelId -> modelId + DictWordType.NATURE_SPILT) + filterCondition.put("modelId", + allModels.stream().map(modelId -> modelId + DictWordType.NATURE_SPILT) .collect(Collectors.toList())); retrieveQuery.setFilterCondition(filterCondition); } @@ -67,36 +64,22 @@ public class MetaEmbeddingService { return result; } // Process each Retrieval object. - List updatedRetrievals = - retrievals.stream() - .flatMap( - retrieval -> { - Long modelId = - Retrieval.getLongId( - retrieval.getMetadata().get("modelId")); - List dataSetIds = modelIdToDataSetIds.get(modelId); + List updatedRetrievals = retrievals.stream().flatMap(retrieval -> { + Long modelId = Retrieval.getLongId(retrieval.getMetadata().get("modelId")); + List dataSetIds = modelIdToDataSetIds.get(modelId); - if (CollectionUtils.isEmpty(dataSetIds)) { - return Stream.of(retrieval); - } + if (CollectionUtils.isEmpty(dataSetIds)) { + return Stream.of(retrieval); + } - return dataSetIds.stream() - .map( - dataSetId -> { - Retrieval newRetrieval = new Retrieval(); - BeanUtils.copyProperties( - retrieval, newRetrieval); - newRetrieval - .getMetadata() - .putIfAbsent( - "dataSetId", - dataSetId - + Constants - .UNDERLINE); - return newRetrieval; - }); - }) - .collect(Collectors.toList()); + return dataSetIds.stream().map(dataSetId -> { + Retrieval newRetrieval = new Retrieval(); + BeanUtils.copyProperties(retrieval, newRetrieval); + newRetrieval.getMetadata().putIfAbsent("dataSetId", + dataSetId + Constants.UNDERLINE); + return newRetrieval; + }); + }).collect(Collectors.toList()); result.setRetrieval(updatedRetrievals); return result; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/MultiCustomDictionary.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/MultiCustomDictionary.java index bb31c09f2..727aa5236 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/MultiCustomDictionary.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/MultiCustomDictionary.java @@ -60,12 +60,9 @@ public class MultiCustomDictionary extends DynamicCustomDictionary { * @param addToSuggeterTrie * @return */ - public static boolean load( - String path, - Nature defaultNature, + public static boolean load(String path, Nature defaultNature, TreeMap map, - LinkedHashSet customNatureCollector, - boolean addToSuggeterTrie) { + LinkedHashSet customNatureCollector, boolean addToSuggeterTrie) { try { String splitter = "\\s"; if (path.endsWith(".csv")) { @@ -112,9 +109,8 @@ public class MultiCustomDictionary extends DynamicCustomDictionary { attribute = new CoreDictionary.Attribute(natureCount); for (int i = 0; i < natureCount; ++i) { - attribute.nature[i] = - LexiconUtility.convertStringToNature( - param[1 + 2 * i], customNatureCollector); + attribute.nature[i] = LexiconUtility.convertStringToNature(param[1 + 2 * i], + customNatureCollector); attribute.frequency[i] = Integer.parseInt(param[2 + 2 * i]); attribute.originals[i] = original; attribute.totalFrequency += attribute.frequency[i]; @@ -133,10 +129,8 @@ public class MultiCustomDictionary extends DynamicCustomDictionary { Nature nature = attribute.nature[i]; PriorityQueue priorityQueue = NATURE_TO_VALUES.get(nature.toString()); if (Objects.isNull(priorityQueue)) { - priorityQueue = - new PriorityQueue<>( - MAX_SIZE, - Comparator.comparingInt(Term::getFrequency).reversed()); + priorityQueue = new PriorityQueue<>(MAX_SIZE, + Comparator.comparingInt(Term::getFrequency).reversed()); NATURE_TO_VALUES.put(nature.toString(), priorityQueue); } Term term = new Term(word, nature); @@ -159,12 +153,8 @@ public class MultiCustomDictionary extends DynamicCustomDictionary { logger.warning("自定义词典" + Arrays.toString(path) + "加载失败"); return false; } else { - logger.info( - "自定义词典加载成功:" - + this.dat.size() - + "个词条,耗时" - + (System.currentTimeMillis() - start) - + "ms"); + logger.info("自定义词典加载成功:" + this.dat.size() + "个词条,耗时" + + (System.currentTimeMillis() - start) + "ms"); this.path = path; return true; } @@ -180,11 +170,8 @@ public class MultiCustomDictionary extends DynamicCustomDictionary { * @param addToSuggestTrie * @return */ - public static boolean loadMainDictionary( - String mainPath, - String[] path, - DoubleArrayTrie dat, - boolean isCache, + public static boolean loadMainDictionary(String mainPath, String[] path, + DoubleArrayTrie dat, boolean isCache, boolean addToSuggestTrie) { logger.info("自定义词典开始加载:" + mainPath); if (loadDat(mainPath, dat)) { @@ -204,9 +191,8 @@ public class MultiCustomDictionary extends DynamicCustomDictionary { p = file.getParent() + File.separator + fileName.substring(0, cut); try { - defaultNature = - LexiconUtility.convertStringToNature( - nature, customNatureCollector); + defaultNature = LexiconUtility.convertStringToNature(nature, + customNatureCollector); } catch (Exception var16) { logger.severe("配置文件【" + p + "】写错了!" + var16); continue; @@ -241,10 +227,8 @@ public class MultiCustomDictionary extends DynamicCustomDictionary { attributeList.add(entry.getValue()); } - DataOutputStream out = - new DataOutputStream( - new BufferedOutputStream( - IOUtil.newOutputStream(mainPath + ".bin"))); + DataOutputStream out = new DataOutputStream( + new BufferedOutputStream(IOUtil.newOutputStream(mainPath + ".bin"))); if (customNatureCollector.isEmpty()) { for (int i = Nature.begin.ordinal() + 1; i < Nature.values().length; ++i) { Nature nature = Nature.values()[i]; @@ -287,8 +271,8 @@ public class MultiCustomDictionary extends DynamicCustomDictionary { return loadDat(path, HanLP.Config.CustomDictionaryPath, dat); } - public static boolean loadDat( - String path, String[] customDicPath, DoubleArrayTrie dat) { + public static boolean loadDat(String path, String[] customDicPath, + DoubleArrayTrie dat) { try { if (HanLP.Config.CustomDictionaryAutoRefreshCache && DynamicCustomDictionary.isDicNeedUpdate(path, customDicPath)) { @@ -374,8 +358,8 @@ public class MultiCustomDictionary extends DynamicCustomDictionary { IOUtil.deleteFile(this.path[0] + ".bin"); Boolean loadCacheOk = this.loadDat(this.path[0], this.path, this.dat); if (!loadCacheOk) { - return this.loadMainDictionary( - this.path[0], this.path, this.dat, true, addToSuggesterTrie); + return this.loadMainDictionary(this.path[0], this.path, this.dat, true, + addToSuggesterTrie); } } return false; @@ -389,8 +373,7 @@ public class MultiCustomDictionary extends DynamicCustomDictionary { word = CharTable.convert(word); } CoreDictionary.Attribute att = - natureWithFrequency == null - ? new CoreDictionary.Attribute(Nature.nz, 1) + natureWithFrequency == null ? new CoreDictionary.Attribute(Nature.nz, 1) : CoreDictionary.Attribute.create(natureWithFrequency); boolean isLetters = isLetters(word); word = getWordBySpace(word); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/SearchService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/SearchService.java index 9c688fc99..b029650c3 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/SearchService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/SearchService.java @@ -43,35 +43,23 @@ public class SearchService { * @param key * @return */ - public static List prefixSearch( - String key, - int limit, - Map> modelIdToDataSetIds, - Set detectDataSetIds) { + public static List prefixSearch(String key, int limit, + Map> modelIdToDataSetIds, Set detectDataSetIds) { return prefixSearch(key, limit, trie, modelIdToDataSetIds, detectDataSetIds); } - public static List prefixSearch( - String key, - int limit, - BinTrie> binTrie, - Map> modelIdToDataSetIds, + public static List prefixSearch(String key, int limit, + BinTrie> binTrie, Map> modelIdToDataSetIds, Set detectDataSetIds) { Set>> result = search(key, binTrie); - List hanlpMapResults = - result.stream() - .map( - entry -> { - String name = entry.getKey().replace("#", " "); - double similarity = EditDistanceUtils.getSimilarity(name, key); - return new HanlpMapResult( - name, entry.getValue(), key, similarity); - }) - .sorted((a, b) -> -(b.getName().length() - a.getName().length())) - .collect(Collectors.toList()); - hanlpMapResults = - transformAndFilterByDataSet( - hanlpMapResults, modelIdToDataSetIds, detectDataSetIds, limit); + List hanlpMapResults = result.stream().map(entry -> { + String name = entry.getKey().replace("#", " "); + double similarity = EditDistanceUtils.getSimilarity(name, key); + return new HanlpMapResult(name, entry.getValue(), key, similarity); + }).sorted((a, b) -> -(b.getName().length() - a.getName().length())) + .collect(Collectors.toList()); + hanlpMapResults = transformAndFilterByDataSet(hanlpMapResults, modelIdToDataSetIds, + detectDataSetIds, limit); return hanlpMapResults; } @@ -81,87 +69,55 @@ public class SearchService { * @param key * @return */ - public static List suffixSearch( - String key, - int limit, - Map> modelIdToDataSetIds, - Set detectDataSetIds) { + public static List suffixSearch(String key, int limit, + Map> modelIdToDataSetIds, Set detectDataSetIds) { String reverseDetectSegment = StringUtils.reverse(key); - return suffixSearch( - reverseDetectSegment, limit, suffixTrie, modelIdToDataSetIds, detectDataSetIds); + return suffixSearch(reverseDetectSegment, limit, suffixTrie, modelIdToDataSetIds, + detectDataSetIds); } - public static List suffixSearch( - String key, - int limit, - BinTrie> binTrie, - Map> modelIdToDataSetIds, + public static List suffixSearch(String key, int limit, + BinTrie> binTrie, Map> modelIdToDataSetIds, Set detectDataSetIds) { Set>> result = search(key, binTrie); - List hanlpMapResults = - result.stream() - .map( - entry -> { - String name = entry.getKey().replace("#", " "); - List natures = - entry.getValue().stream() - .map( - nature -> - nature.replaceAll( - DictWordType.SUFFIX - .getType(), - "")) - .collect(Collectors.toList()); + List hanlpMapResults = result.stream().map(entry -> { + String name = entry.getKey().replace("#", " "); + List natures = entry.getValue().stream() + .map(nature -> nature.replaceAll(DictWordType.SUFFIX.getType(), "")) + .collect(Collectors.toList()); - name = StringUtils.reverse(name); - double similarity = EditDistanceUtils.getSimilarity(name, key); - return new HanlpMapResult(name, natures, key, similarity); - }) - .sorted((a, b) -> -(b.getName().length() - a.getName().length())) - .collect(Collectors.toList()); - return transformAndFilterByDataSet( - hanlpMapResults, modelIdToDataSetIds, detectDataSetIds, limit); + name = StringUtils.reverse(name); + double similarity = EditDistanceUtils.getSimilarity(name, key); + return new HanlpMapResult(name, natures, key, similarity); + }).sorted((a, b) -> -(b.getName().length() - a.getName().length())) + .collect(Collectors.toList()); + return transformAndFilterByDataSet(hanlpMapResults, modelIdToDataSetIds, detectDataSetIds, + limit); } private static List transformAndFilterByDataSet( - List hanlpMapResults, - Map> modelIdToDataSetIds, - Set detectDataSetIds, - int limit) { - return hanlpMapResults.stream() - .peek( - hanlpMapResult -> { - List natures = - hanlpMapResult.getNatures().stream() - .map( - nature -> - NatureHelper.changeModel2DataSet( - nature, modelIdToDataSetIds)) - .flatMap(Collection::stream) - .filter( - nature -> { - if (CollectionUtils.isEmpty( - detectDataSetIds)) { - return true; - } - Long dataSetId = - NatureHelper.getDataSetId(nature); - if (dataSetId != null) { - return detectDataSetIds.contains( - dataSetId); - } - return false; - }) - .collect(Collectors.toList()); - hanlpMapResult.setNatures(natures); - }) - .filter(hanlpMapResult -> !CollectionUtils.isEmpty(hanlpMapResult.getNatures())) - .limit(limit) - .collect(Collectors.toList()); + List hanlpMapResults, Map> modelIdToDataSetIds, + Set detectDataSetIds, int limit) { + return hanlpMapResults.stream().peek(hanlpMapResult -> { + List natures = hanlpMapResult.getNatures().stream() + .map(nature -> NatureHelper.changeModel2DataSet(nature, modelIdToDataSetIds)) + .flatMap(Collection::stream).filter(nature -> { + if (CollectionUtils.isEmpty(detectDataSetIds)) { + return true; + } + Long dataSetId = NatureHelper.getDataSetId(nature); + if (dataSetId != null) { + return detectDataSetIds.contains(dataSetId); + } + return false; + }).collect(Collectors.toList()); + hanlpMapResult.setNatures(natures); + }).filter(hanlpMapResult -> !CollectionUtils.isEmpty(hanlpMapResult.getNatures())) + .limit(limit).collect(Collectors.toList()); } - private static Set>> search( - String key, BinTrie> binTrie) { + private static Set>> search(String key, + BinTrie> binTrie) { key = key.toLowerCase(); Set>> entrySet = new TreeSet>>(); @@ -202,14 +158,12 @@ public class SearchService { } TreeMap map = new TreeMap(); for (DictWord suffix : suffixes) { - CoreDictionary.Attribute attributeNew = - suffix.getNatureWithFrequency() == null - ? new CoreDictionary.Attribute(Nature.nz, 1) - : CoreDictionary.Attribute.create(suffix.getNatureWithFrequency()); + CoreDictionary.Attribute attributeNew = suffix.getNatureWithFrequency() == null + ? new CoreDictionary.Attribute(Nature.nz, 1) + : CoreDictionary.Attribute.create(suffix.getNatureWithFrequency()); if (map.containsKey(suffix.getWord())) { - attributeNew = - DictionaryAttributeUtil.getAttribute( - map.get(suffix.getWord()), attributeNew); + attributeNew = DictionaryAttributeUtil.getAttribute(map.get(suffix.getWord()), + attributeNew); } map.put(suffix.getWord(), attributeNew); } @@ -239,11 +193,8 @@ public class SearchService { } public static List getDimensionValue(DimensionValueReq dimensionValueReq) { - String nature = - DictWordType.NATURE_SPILT - + dimensionValueReq.getModelId() - + DictWordType.NATURE_SPILT - + dimensionValueReq.getElementID(); + String nature = DictWordType.NATURE_SPILT + dimensionValueReq.getModelId() + + DictWordType.NATURE_SPILT + dimensionValueReq.getElementID(); PriorityQueue terms = MultiCustomDictionary.NATURE_TO_VALUES.get(nature); if (CollectionUtils.isEmpty(terms)) { return new ArrayList<>(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/BaseWordWithAliasBuilder.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/BaseWordWithAliasBuilder.java index 67c2b7066..35d5c9453 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/BaseWordWithAliasBuilder.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/BaseWordWithAliasBuilder.java @@ -9,8 +9,8 @@ import java.util.List; public abstract class BaseWordWithAliasBuilder extends BaseWordBuilder { - public abstract DictWord getOneWordNature( - String word, SchemaElement schemaElement, boolean isSuffix); + public abstract DictWord getOneWordNature(String word, SchemaElement schemaElement, + boolean isSuffix); public List getOneWordNatureAlias(SchemaElement schemaElement, boolean isSuffix) { List dictWords = new ArrayList<>(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/DimensionWordBuilder.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/DimensionWordBuilder.java index 598bb3dfe..ea84c3ae2 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/DimensionWordBuilder.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/DimensionWordBuilder.java @@ -29,20 +29,12 @@ public class DimensionWordBuilder extends BaseWordWithAliasBuilder { DictWord dictWord = new DictWord(); dictWord.setWord(word); Long modelId = schemaElement.getModel(); - String nature = - DictWordType.NATURE_SPILT - + modelId - + DictWordType.NATURE_SPILT - + schemaElement.getId() - + DictWordType.DIMENSION.getType(); + String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + + schemaElement.getId() + DictWordType.DIMENSION.getType(); if (isSuffix) { - nature = - DictWordType.NATURE_SPILT - + modelId - + DictWordType.NATURE_SPILT - + schemaElement.getId() - + DictWordType.SUFFIX.getType() - + DictWordType.DIMENSION.getType(); + nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + + schemaElement.getId() + DictWordType.SUFFIX.getType() + + DictWordType.DIMENSION.getType(); } dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature)); return dictWord; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/EntityWordBuilder.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/EntityWordBuilder.java index bde4f6c64..d4eb900ca 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/EntityWordBuilder.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/EntityWordBuilder.java @@ -27,12 +27,8 @@ public class EntityWordBuilder extends BaseWordWithAliasBuilder { @Override public DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix) { - String nature = - DictWordType.NATURE_SPILT - + schemaElement.getModel() - + DictWordType.NATURE_SPILT - + schemaElement.getId() - + DictWordType.ENTITY.getType(); + String nature = DictWordType.NATURE_SPILT + schemaElement.getModel() + + DictWordType.NATURE_SPILT + schemaElement.getId() + DictWordType.ENTITY.getType(); DictWord dictWord = new DictWord(); dictWord.setWord(word); dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY * 2, nature)); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/MetricWordBuilder.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/MetricWordBuilder.java index 002238a96..54a9d6d1c 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/MetricWordBuilder.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/MetricWordBuilder.java @@ -29,20 +29,12 @@ public class MetricWordBuilder extends BaseWordWithAliasBuilder { DictWord dictWord = new DictWord(); dictWord.setWord(word); Long modelId = schemaElement.getModel(); - String nature = - DictWordType.NATURE_SPILT - + modelId - + DictWordType.NATURE_SPILT - + schemaElement.getId() - + DictWordType.METRIC.getType(); + String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + + schemaElement.getId() + DictWordType.METRIC.getType(); if (isSuffix) { - nature = - DictWordType.NATURE_SPILT - + modelId - + DictWordType.NATURE_SPILT - + schemaElement.getId() - + DictWordType.SUFFIX.getType() - + DictWordType.METRIC.getType(); + nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + + schemaElement.getId() + DictWordType.SUFFIX.getType() + + DictWordType.METRIC.getType(); } dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature)); return dictWord; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/TermWordBuilder.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/TermWordBuilder.java index adc7ce18d..5dfd7d64d 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/TermWordBuilder.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/TermWordBuilder.java @@ -29,20 +29,12 @@ public class TermWordBuilder extends BaseWordWithAliasBuilder { DictWord dictWord = new DictWord(); dictWord.setWord(word); Long dataSet = schemaElement.getDataSetId(); - String nature = - DictWordType.NATURE_SPILT - + dataSet - + DictWordType.NATURE_SPILT - + schemaElement.getId() - + DictWordType.TERM.getType(); + String nature = DictWordType.NATURE_SPILT + dataSet + DictWordType.NATURE_SPILT + + schemaElement.getId() + DictWordType.TERM.getType(); if (isSuffix) { - nature = - DictWordType.NATURE_SPILT - + dataSet - + DictWordType.NATURE_SPILT - + schemaElement.getId() - + DictWordType.SUFFIX.getType() - + DictWordType.TERM.getType(); + nature = DictWordType.NATURE_SPILT + dataSet + DictWordType.NATURE_SPILT + + schemaElement.getId() + DictWordType.SUFFIX.getType() + + DictWordType.TERM.getType(); } dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature)); return dictWord; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/ValueWordBuilder.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/ValueWordBuilder.java index f316589d6..8a4dd9efc 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/ValueWordBuilder.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/builder/ValueWordBuilder.java @@ -26,11 +26,8 @@ public class ValueWordBuilder extends BaseWordWithAliasBuilder { public DictWord getOneWordNature(String word, SchemaElement schemaElement, boolean isSuffix) { DictWord dictWord = new DictWord(); Long modelId = schemaElement.getModel(); - String nature = - DictWordType.NATURE_SPILT - + modelId - + DictWordType.NATURE_SPILT - + schemaElement.getId(); + String nature = DictWordType.NATURE_SPILT + modelId + DictWordType.NATURE_SPILT + + schemaElement.getId(); dictWord.setNatureWithFrequency(String.format("%s " + DEFAULT_FREQUENCY, nature)); dictWord.setWord(word); return dictWord; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/file/FileHandlerImpl.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/file/FileHandlerImpl.java index 2172d4452..7ea1212bd 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/file/FileHandlerImpl.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/file/FileHandlerImpl.java @@ -81,12 +81,8 @@ public class FileHandlerImpl implements FileHandler { String filePath = localFileConfig.getDictDirectoryLatest() + FILE_SPILT + fileName; Long fileLineNum = getFileLineNum(filePath); Integer startLine = (dictValueReq.getCurrent() - 1) * dictValueReq.getPageSize() + 1; - Integer endLine = - Integer.valueOf( - Math.min( - dictValueReq.getCurrent() * dictValueReq.getPageSize(), - fileLineNum) - + ""); + Integer endLine = Integer.valueOf( + Math.min(dictValueReq.getCurrent() * dictValueReq.getPageSize(), fileLineNum) + ""); List dictValueRespList = getFileData(filePath, startLine, endLine); dictValueRespPageInfo.setPageSize(dictValueReq.getPageSize()); @@ -112,12 +108,9 @@ public class FileHandlerImpl implements FileHandler { List fileData = new ArrayList<>(); try (Stream lines = Files.lines(Paths.get(filePath))) { - fileData = - lines.skip(startLine - 1) - .limit(endLine - startLine + 1) - .map(lineStr -> convert2Resp(lineStr)) - .filter(line -> Objects.nonNull(line)) - .collect(Collectors.toList()); + fileData = lines.skip(startLine - 1).limit(endLine - startLine + 1) + .map(lineStr -> convert2Resp(lineStr)).filter(line -> Objects.nonNull(line)) + .collect(Collectors.toList()); } catch (IOException e) { log.warn("[getFileData] e:{}", e); } @@ -204,8 +197,8 @@ public class FileHandlerImpl implements FileHandler { private BufferedWriter getWriter(String filePath, Boolean append) throws IOException { if (append) { - return Files.newBufferedWriter( - Paths.get(filePath), StandardCharsets.UTF_8, StandardOpenOption.APPEND); + return Files.newBufferedWriter(Paths.get(filePath), StandardCharsets.UTF_8, + StandardOpenOption.APPEND); } return Files.newBufferedWriter(Paths.get(filePath), StandardCharsets.UTF_8); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/FileHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/FileHelper.java index 3b815543b..b42669056 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/FileHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/FileHelper.java @@ -32,17 +32,15 @@ public class FileHelper { } private static File[] getFileList(File customFolder, String suffix) { - File[] customSubFiles = - customFolder.listFiles( - file -> { - if (file.isDirectory()) { - return false; - } - if (file.getName().toLowerCase().endsWith(suffix)) { - return true; - } - return false; - }); + File[] customSubFiles = customFolder.listFiles(file -> { + if (file.isDirectory()) { + return false; + } + if (file.getName().toLowerCase().endsWith(suffix)) { + return true; + } + return false; + }); return customSubFiles; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/HanlpHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/HanlpHelper.java index ee73409c3..a3e1ae9b0 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/HanlpHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/HanlpHelper.java @@ -57,21 +57,14 @@ public class HanlpHelper { if (segment == null) { synchronized (HanlpHelper.class) { if (segment == null) { - segment = - HanLP.newSegment() - .enableIndexMode(true) - .enableIndexMode(4) - .enableCustomDictionary(true) - .enableCustomDictionaryForcing(true) - .enableOffset(true) - .enableJapaneseNameRecognize(false) - .enableNameRecognize(false) - .enableAllNamedEntityRecognize(false) - .enableJapaneseNameRecognize(false) - .enableNumberQuantifierRecognize(false) - .enablePlaceRecognize(false) - .enableOrganizationRecognize(false) - .enableCustomDictionary(getDynamicCustomDictionary()); + segment = HanLP.newSegment().enableIndexMode(true).enableIndexMode(4) + .enableCustomDictionary(true).enableCustomDictionaryForcing(true) + .enableOffset(true).enableJapaneseNameRecognize(false) + .enableNameRecognize(false).enableAllNamedEntityRecognize(false) + .enableJapaneseNameRecognize(false) + .enableNumberQuantifierRecognize(false).enablePlaceRecognize(false) + .enableOrganizationRecognize(false) + .enableCustomDictionary(getDynamicCustomDictionary()); } } } @@ -112,8 +105,7 @@ public class HanlpHelper { boolean reload = getDynamicCustomDictionary().reload(); if (reload) { - log.info( - "Custom dictionary has been reloaded in {} milliseconds", + log.info("Custom dictionary has been reloaded in {} milliseconds", System.currentTimeMillis() - startTime); } return reload; @@ -125,21 +117,15 @@ public class HanlpHelper { } String hanlpPropertiesPath = getHanlpPropertiesPath(); - HanLP.Config.CustomDictionaryPath = - Arrays.stream(HanLP.Config.CustomDictionaryPath) - .map(path -> hanlpPropertiesPath + FILE_SPILT + path) - .toArray(String[]::new); - log.info( - "hanlpPropertiesPath:{},CustomDictionaryPath:{}", - hanlpPropertiesPath, + HanLP.Config.CustomDictionaryPath = Arrays.stream(HanLP.Config.CustomDictionaryPath) + .map(path -> hanlpPropertiesPath + FILE_SPILT + path).toArray(String[]::new); + log.info("hanlpPropertiesPath:{},CustomDictionaryPath:{}", hanlpPropertiesPath, HanLP.Config.CustomDictionaryPath); HanLP.Config.CoreDictionaryPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.BiGramDictionaryPath; - HanLP.Config.CoreDictionaryTransformMatrixDictionaryPath = - hanlpPropertiesPath - + FILE_SPILT - + HanLP.Config.CoreDictionaryTransformMatrixDictionaryPath; + HanLP.Config.CoreDictionaryTransformMatrixDictionaryPath = hanlpPropertiesPath + FILE_SPILT + + HanLP.Config.CoreDictionaryTransformMatrixDictionaryPath; HanLP.Config.BiGramDictionaryPath = hanlpPropertiesPath + FILE_SPILT + HanLP.Config.BiGramDictionaryPath; HanLP.Config.CoreStopWordDictionaryPath = @@ -201,8 +187,8 @@ public class HanlpHelper { public static boolean addToCustomDictionary(DictWord dictWord) { log.debug("dictWord:{}", dictWord); - return getDynamicCustomDictionary() - .insert(dictWord.getWord(), dictWord.getNatureWithFrequency()); + return getDynamicCustomDictionary().insert(dictWord.getWord(), + dictWord.getNatureWithFrequency()); } public static void removeFromCustomDictionary(DictWord dictWord) { @@ -226,8 +212,8 @@ public class HanlpHelper { int len = natureWithFrequency.length(); log.info("filtered natureWithFrequency:{}", natureWithFrequency); if (StringUtils.isNotBlank(natureWithFrequency)) { - getDynamicCustomDictionary() - .add(dictWord.getWord(), natureWithFrequency.substring(0, len - 1)); + getDynamicCustomDictionary().add(dictWord.getWord(), + natureWithFrequency.substring(0, len - 1)); } SearchService.remove(dictWord, natureList.toArray(new Nature[0])); } @@ -257,8 +243,8 @@ public class HanlpHelper { mapResults.addAll(newResults); } - public static boolean addLetterOriginal( - List mapResults, T mapResult, CoreDictionary.Attribute attribute) { + public static boolean addLetterOriginal(List mapResults, T mapResult, + CoreDictionary.Attribute attribute) { if (attribute == null) { return false; } @@ -268,12 +254,8 @@ public class HanlpHelper { for (String nature : hanlpMapResult.getNatures()) { String orig = attribute.getOriginal(Nature.fromString(nature)); if (orig != null) { - MapResult addMapResult = - new HanlpMapResult( - orig, - Arrays.asList(nature), - hanlpMapResult.getDetectWord(), - hanlpMapResult.getSimilarity()); + MapResult addMapResult = new HanlpMapResult(orig, Arrays.asList(nature), + hanlpMapResult.getDetectWord(), hanlpMapResult.getSimilarity()); mapResults.add((T) addMapResult); isAdd = true; } @@ -317,38 +299,30 @@ public class HanlpHelper { return getSegment().seg(text.toLowerCase()).stream() .filter(term -> term.getNature().startsWith(DictWordType.NATURE_SPILT)) .map(term -> transform2ApiTerm(term, modelIdToDataSetIds)) - .flatMap(Collection::stream) - .collect(Collectors.toList()); + .flatMap(Collection::stream).collect(Collectors.toList()); } public static List getTerms(List terms, Set dataSetIds) { logTerms(terms); if (!CollectionUtils.isEmpty(dataSetIds)) { - terms = - terms.stream() - .filter( - term -> { - Long dataSetId = - NatureHelper.getDataSetId( - term.getNature().toString()); - if (Objects.nonNull(dataSetId)) { - return dataSetIds.contains(dataSetId); - } - return false; - }) - .collect(Collectors.toList()); + terms = terms.stream().filter(term -> { + Long dataSetId = NatureHelper.getDataSetId(term.getNature().toString()); + if (Objects.nonNull(dataSetId)) { + return dataSetIds.contains(dataSetId); + } + return false; + }).collect(Collectors.toList()); log.debug("terms filter by dataSetId:{}", dataSetIds); logTerms(terms); } return terms; } - public static List transform2ApiTerm( - Term term, Map> modelIdToDataSetIds) { + public static List transform2ApiTerm(Term term, + Map> modelIdToDataSetIds) { List s2Terms = Lists.newArrayList(); - List natures = - NatureHelper.changeModel2DataSet( - String.valueOf(term.getNature()), modelIdToDataSetIds); + List natures = NatureHelper.changeModel2DataSet(String.valueOf(term.getNature()), + modelIdToDataSetIds); for (String nature : natures) { S2Term s2Term = new S2Term(); BeanUtils.copyProperties(term, s2Term); @@ -364,10 +338,7 @@ public class HanlpHelper { return; } for (S2Term term : terms) { - log.debug( - "word:{},nature:{},frequency:{}", - term.word, - term.nature.toString(), + log.debug("word:{},nature:{},frequency:{}", term.word, term.nature.toString(), term.getFrequency()); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/NatureHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/NatureHelper.java index 7c4b5ce39..371651f03 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/NatureHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/knowledge/helper/NatureHelper.java @@ -89,8 +89,8 @@ public class NatureHelper { return null; } - public static List changeModel2DataSet( - String nature, Map> modelIdToDataSetIds) { + public static List changeModel2DataSet(String nature, + Map> modelIdToDataSetIds) { if (SchemaElementType.TERM.equals(NatureHelper.convertToElementType(nature))) { return Collections.singletonList(nature); } @@ -99,77 +99,56 @@ public class NatureHelper { if (CollectionUtils.isEmpty(dataSetIds)) { return Collections.emptyList(); } - return dataSetIds.stream() - .map(dataSetId -> changeModel2DataSet(nature, dataSetId)) - .filter(Objects::nonNull) - .map(String::valueOf) - .collect(Collectors.toList()); + return dataSetIds.stream().map(dataSetId -> changeModel2DataSet(nature, dataSetId)) + .filter(Objects::nonNull).map(String::valueOf).collect(Collectors.toList()); } public static boolean isDimensionValueDataSetId(String nature) { return isNatureValid(nature) - && !isNatureType( - nature, DictWordType.METRIC, DictWordType.DIMENSION, DictWordType.TERM) + && !isNatureType(nature, DictWordType.METRIC, DictWordType.DIMENSION, + DictWordType.TERM) && StringUtils.isNumeric(nature.split(DictWordType.NATURE_SPILT)[1]); } public static DataSetInfoStat getDataSetStat(List terms) { - return DataSetInfoStat.builder() - .dataSetCount(getDataSetCount(terms)) + return DataSetInfoStat.builder().dataSetCount(getDataSetCount(terms)) .dimensionDataSetCount(getDimensionCount(terms)) .metricDataSetCount(getMetricCount(terms)) - .dimensionValueDataSetCount(getDimensionValueCount(terms)) - .build(); + .dimensionValueDataSetCount(getDimensionValueCount(terms)).build(); } private static long getDataSetCount(List terms) { return terms.stream() - .filter(term -> isDataSetOrEntity(term, getDataSetByNature(term.nature))) - .count(); + .filter(term -> isDataSetOrEntity(term, getDataSetByNature(term.nature))).count(); } private static long getDimensionValueCount(List terms) { - return terms.stream() - .filter(term -> isDimensionValueDataSetId(term.nature.toString())) + return terms.stream().filter(term -> isDimensionValueDataSetId(term.nature.toString())) .count(); } private static long getDimensionCount(List terms) { return terms.stream() - .filter( - term -> - term.nature.startsWith(DictWordType.NATURE_SPILT) - && term.nature - .toString() - .endsWith(DictWordType.DIMENSION.getType())) + .filter(term -> term.nature.startsWith(DictWordType.NATURE_SPILT) + && term.nature.toString().endsWith(DictWordType.DIMENSION.getType())) .count(); } private static long getMetricCount(List terms) { - return terms.stream() - .filter( - term -> - term.nature.startsWith(DictWordType.NATURE_SPILT) - && term.nature - .toString() - .endsWith(DictWordType.METRIC.getType())) - .count(); + return terms.stream().filter(term -> term.nature.startsWith(DictWordType.NATURE_SPILT) + && term.nature.toString().endsWith(DictWordType.METRIC.getType())).count(); } public static Map> getDataSetToNatureStat(List terms) { Map> modelToNature = new HashMap<>(); - terms.stream() - .filter(term -> term.nature.startsWith(DictWordType.NATURE_SPILT)) - .forEach( - term -> { - DictWordType dictWordType = - DictWordType.getNatureType(term.nature.toString()); - Long model = getDataSetId(term.nature.toString()); + terms.stream().filter(term -> term.nature.startsWith(DictWordType.NATURE_SPILT)) + .forEach(term -> { + DictWordType dictWordType = DictWordType.getNatureType(term.nature.toString()); + Long model = getDataSetId(term.nature.toString()); - modelToNature - .computeIfAbsent(model, k -> new HashMap<>()) - .merge(dictWordType, 1, Integer::sum); - }); + modelToNature.computeIfAbsent(model, k -> new HashMap<>()).merge(dictWordType, + 1, Integer::sum); + }); return modelToNature; } @@ -177,12 +156,9 @@ public class NatureHelper { Map> modelToNatureStat = getDataSetToNatureStat(terms); return modelToNatureStat.entrySet().stream() .max(Comparator.comparingInt(entry -> entry.getValue().size())) - .map( - entry -> - modelToNatureStat.entrySet().stream() - .filter(e -> e.getValue().size() == entry.getValue().size()) - .map(Map.Entry::getKey) - .collect(Collectors.toList())) + .map(entry -> modelToNatureStat.entrySet().stream() + .filter(e -> e.getValue().size() == entry.getValue().size()) + .map(Map.Entry::getKey).collect(Collectors.toList())) .orElse(Collections.emptyList()); } @@ -190,15 +166,14 @@ public class NatureHelper { return parseIdFromNature(nature, 2); } - public static Set getModelIds( - Map> modelIdToDataSetIds, Set detectDataSetIds) { + public static Set getModelIds(Map> modelIdToDataSetIds, + Set detectDataSetIds) { if (CollectionUtils.isEmpty(detectDataSetIds)) { return modelIdToDataSetIds.keySet(); } return modelIdToDataSetIds.entrySet().stream() .filter(entry -> !Collections.disjoint(entry.getValue(), detectDataSetIds)) - .map(Map.Entry::getKey) - .collect(Collectors.toSet()); + .map(Map.Entry::getKey).collect(Collectors.toSet()); } public static Long parseIdFromNature(String nature, int index) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMapper.java index cec1f32f8..4198a2424 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMapper.java @@ -30,9 +30,7 @@ public abstract class BaseMapper implements SchemaMapper { String simpleName = this.getClass().getSimpleName(); long startTime = System.currentTimeMillis(); - log.debug( - "before {},mapInfo:{}", - simpleName, + log.debug("before {},mapInfo:{}", simpleName, chatQueryContext.getMapInfo().getDataSetElementMatches()); try { @@ -43,17 +41,14 @@ public abstract class BaseMapper implements SchemaMapper { } long cost = System.currentTimeMillis() - startTime; - log.debug( - "after {},cost:{},mapInfo:{}", - simpleName, - cost, + log.debug("after {},cost:{},mapInfo:{}", simpleName, cost, chatQueryContext.getMapInfo().getDataSetElementMatches()); } public abstract void doMap(ChatQueryContext chatQueryContext); - public void addToSchemaMap( - SchemaMapInfo schemaMap, Long dataSetId, SchemaElementMatch newElementMatch) { + public void addToSchemaMap(SchemaMapInfo schemaMap, Long dataSetId, + SchemaElementMatch newElementMatch) { Map> dataSetElementMatches = schemaMap.getDataSetElementMatches(); List schemaElementMatches = @@ -61,26 +56,24 @@ public abstract class BaseMapper implements SchemaMapper { AtomicBoolean shouldAddNew = new AtomicBoolean(true); - schemaElementMatches.removeIf( - existingElementMatch -> { - if (isEquals(existingElementMatch, newElementMatch)) { - if (newElementMatch.getSimilarity() - > existingElementMatch.getSimilarity()) { - return true; - } else { - shouldAddNew.set(false); - } - } - return false; - }); + schemaElementMatches.removeIf(existingElementMatch -> { + if (isEquals(existingElementMatch, newElementMatch)) { + if (newElementMatch.getSimilarity() > existingElementMatch.getSimilarity()) { + return true; + } else { + shouldAddNew.set(false); + } + } + return false; + }); if (shouldAddNew.get()) { schemaElementMatches.add(newElementMatch); } } - private static boolean isEquals( - SchemaElementMatch existElementMatch, SchemaElementMatch newElementMatch) { + private static boolean isEquals(SchemaElementMatch existElementMatch, + SchemaElementMatch newElementMatch) { SchemaElement existElement = existElementMatch.getElement(); SchemaElement newElement = newElementMatch.getElement(); if (!existElement.equals(newElement)) { @@ -92,11 +85,8 @@ public abstract class BaseMapper implements SchemaMapper { return true; } - public SchemaElement getSchemaElement( - Long dataSetId, - SchemaElementType elementType, - Long elementID, - SemanticSchema semanticSchema) { + public SchemaElement getSchemaElement(Long dataSetId, SchemaElementType elementType, + Long elementID, SemanticSchema semanticSchema) { SchemaElement element = new SchemaElement(); DataSetSchema dataSetSchema = semanticSchema.getDataSetSchemaMap().get(dataSetId); if (Objects.isNull(dataSetSchema)) { @@ -124,8 +114,8 @@ public abstract class BaseMapper implements SchemaMapper { return element.getAlias(); } - public List getMatches( - ChatQueryContext chatQueryContext, BaseMatchStrategy matchStrategy) { + public List getMatches(ChatQueryContext chatQueryContext, + BaseMatchStrategy matchStrategy) { String queryText = chatQueryContext.getQueryText(); List terms = HanlpHelper.getTerms(queryText, chatQueryContext.getModelIdToDataSetIds()); @@ -136,11 +126,9 @@ public abstract class BaseMapper implements SchemaMapper { if (Objects.isNull(matchResult)) { return matches; } - Optional> first = - matchResult.entrySet().stream() - .filter(entry -> CollectionUtils.isNotEmpty(entry.getValue())) - .map(entry -> entry.getValue()) - .findFirst(); + Optional> first = matchResult.entrySet().stream() + .filter(entry -> CollectionUtils.isNotEmpty(entry.getValue())) + .map(entry -> entry.getValue()).findFirst(); if (first.isPresent()) { matches = first.get(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java index e83eed1a3..2aebb83f4 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BaseMatchStrategy.java @@ -19,8 +19,8 @@ import java.util.Set; @Slf4j public abstract class BaseMatchStrategy implements MatchStrategy { @Override - public Map> match( - ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds) { + public Map> match(ChatQueryContext chatQueryContext, List terms, + Set detectDataSetIds) { String text = chatQueryContext.getQueryText(); if (Objects.isNull(terms) || StringUtils.isEmpty(text)) { return null; @@ -35,8 +35,8 @@ public abstract class BaseMatchStrategy implements MatchStr return result; } - public List detect( - ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds) { + public List detect(ChatQueryContext chatQueryContext, List terms, + Set detectDataSetIds) { throw new RuntimeException("Not implemented"); } @@ -46,15 +46,13 @@ public abstract class BaseMatchStrategy implements MatchStr } for (T oneRoundResult : oneRoundResults) { if (existResults.contains(oneRoundResult)) { - boolean isDeleted = - existResults.removeIf( - existResult -> { - boolean delete = existResult.lessSimilar(oneRoundResult); - if (delete) { - log.info("deleted existResult:{}", existResult); - } - return delete; - }); + boolean isDeleted = existResults.removeIf(existResult -> { + boolean delete = existResult.lessSimilar(oneRoundResult); + if (delete) { + log.info("deleted existResult:{}", existResult); + } + return delete; + }); if (isDeleted) { log.info("deleted, add oneRoundResult:{}", oneRoundResult); existResults.add(oneRoundResult); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BatchMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BatchMatchStrategy.java index 8c4f020e3..7f6a58e38 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BatchMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/BatchMatchStrategy.java @@ -15,22 +15,21 @@ import java.util.Set; @Slf4j public abstract class BatchMatchStrategy extends BaseMatchStrategy { - @Autowired protected MapperConfig mapperConfig; + @Autowired + protected MapperConfig mapperConfig; @Override - public List detect( - ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds) { + public List detect(ChatQueryContext chatQueryContext, List terms, + Set detectDataSetIds) { String text = chatQueryContext.getQueryText(); Set detectSegments = new HashSet<>(); - int embeddingTextSize = - Integer.valueOf( - mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_SIZE)); + int embeddingTextSize = Integer + .valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_SIZE)); - int embeddingTextStep = - Integer.valueOf( - mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_STEP)); + int embeddingTextStep = Integer + .valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_TEXT_STEP)); for (int startIndex = 0; startIndex < text.length(); startIndex += embeddingTextStep) { int endIndex = Math.min(startIndex + embeddingTextSize, text.length()); @@ -40,8 +39,6 @@ public abstract class BatchMatchStrategy extends BaseMatchS return detectByBatch(chatQueryContext, detectDataSetIds, detectSegments); } - public abstract List detectByBatch( - ChatQueryContext chatQueryContext, - Set detectDataSetIds, - Set detectSegments); + public abstract List detectByBatch(ChatQueryContext chatQueryContext, + Set detectDataSetIds, Set detectSegments); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java index 1218e7ffc..eaaf662d7 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/DatabaseMatchStrategy.java @@ -30,17 +30,14 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy allElements; @Override - public Map> match( - ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds) { + public Map> match(ChatQueryContext chatQueryContext, + List terms, Set detectDataSetIds) { this.allElements = getSchemaElements(chatQueryContext); return super.match(chatQueryContext, terms, detectDataSetIds); } - public List detectByStep( - ChatQueryContext chatQueryContext, - Set detectDataSetIds, - String detectSegment, - int offset) { + public List detectByStep(ChatQueryContext chatQueryContext, + Set detectDataSetIds, String detectSegment, int offset) { if (StringUtils.isBlank(detectSegment)) { return new ArrayList<>(); } @@ -56,13 +53,9 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy schemaElements = entry.getValue(); if (!CollectionUtils.isEmpty(detectDataSetIds)) { - schemaElements = - schemaElements.stream() - .filter( - schemaElement -> - detectDataSetIds.contains( - schemaElement.getDataSetId())) - .collect(Collectors.toSet()); + schemaElements = schemaElements.stream().filter( + schemaElement -> detectDataSetIds.contains(schemaElement.getDataSetId())) + .collect(Collectors.toSet()); } for (SchemaElement schemaElement : schemaElements) { DatabaseMapResult databaseMapResult = new DatabaseMapResult(); @@ -86,40 +79,31 @@ public class DatabaseMatchStrategy extends SingleMatchStrategy> modelElementMatches = chatQueryContext.getMapInfo().getDataSetElementMatches(); - boolean existElement = - modelElementMatches.entrySet().stream() - .anyMatch(entry -> entry.getValue().size() >= 1); + boolean existElement = modelElementMatches.entrySet().stream() + .anyMatch(entry -> entry.getValue().size() >= 1); if (!existElement) { threshold = threshold / 2; - log.debug( - "ModelElementMatches:{},not exist Element threshold reduce by half:{}", - modelElementMatches, - threshold); + log.debug("ModelElementMatches:{},not exist Element threshold reduce by half:{}", + modelElementMatches, threshold); } return getThreshold(threshold, minThreshold, chatQueryContext.getMapModeEnum()); } private Map> getNameToItems(List models) { - return models.stream() - .collect( - Collectors.toMap( - SchemaElement::getName, - a -> { - Set result = new HashSet<>(); - result.add(a); - return result; - }, - (k1, k2) -> { - k1.addAll(k2); - return k1; - })); + return models.stream().collect(Collectors.toMap(SchemaElement::getName, a -> { + Set result = new HashSet<>(); + result.add(a); + return result; + }, (k1, k2) -> { + k1.addAll(k2); + return k1; + })); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMapper.java index 6198ba7bb..d18950c48 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMapper.java @@ -35,23 +35,15 @@ public class EmbeddingMapper extends BaseMapper { } SchemaElementType elementType = SchemaElementType.valueOf(matchResult.getMetadata().get("type")); - SchemaElement schemaElement = - getSchemaElement( - dataSetId, - elementType, - elementId, - chatQueryContext.getSemanticSchema()); + SchemaElement schemaElement = getSchemaElement(dataSetId, elementType, elementId, + chatQueryContext.getSemanticSchema()); if (schemaElement == null) { continue; } - SchemaElementMatch schemaElementMatch = - SchemaElementMatch.builder() - .element(schemaElement) - .frequency(BaseWordBuilder.DEFAULT_FREQUENCY) - .word(matchResult.getName()) - .similarity(matchResult.getSimilarity()) - .detectWord(matchResult.getDetectWord()) - .build(); + SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder() + .element(schemaElement).frequency(BaseWordBuilder.DEFAULT_FREQUENCY) + .word(matchResult.getName()).similarity(matchResult.getSimilarity()) + .detectWord(matchResult.getDetectWord()).build(); // 3. add to mapInfo addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java index f7e343df5..ccef43890 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EmbeddingMatchStrategy.java @@ -35,21 +35,18 @@ import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.EMBEDDING @Slf4j public class EmbeddingMatchStrategy extends BatchMatchStrategy { - @Autowired private MetaEmbeddingService metaEmbeddingService; + @Autowired + private MetaEmbeddingService metaEmbeddingService; @Override - public List detectByBatch( - ChatQueryContext chatQueryContext, - Set detectDataSetIds, - Set detectSegments) { + public List detectByBatch(ChatQueryContext chatQueryContext, + Set detectDataSetIds, Set detectSegments) { Set results = new HashSet<>(); - int embeddingMapperBatch = - Integer.valueOf( - mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH)); + int embeddingMapperBatch = Integer + .valueOf(mapperConfig.getParameterValue(MapperConfig.EMBEDDING_MAPPER_BATCH)); List queryTextsList = - detectSegments.stream() - .map(detectSegment -> detectSegment.trim()) + detectSegments.stream().map(detectSegment -> detectSegment.trim()) .filter(detectSegment -> StringUtils.isNotBlank(detectSegment)) .collect(Collectors.toList()); @@ -64,20 +61,15 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy return new ArrayList<>(results); } - private List detectByQueryTextsSub( - Set detectDataSetIds, - List queryTextsSub, - ChatQueryContext chatQueryContext) { + private List detectByQueryTextsSub(Set detectDataSetIds, + List queryTextsSub, ChatQueryContext chatQueryContext) { Map> modelIdToDataSetIds = chatQueryContext.getModelIdToDataSetIds(); double embeddingThreshold = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD)); double embeddingThresholdMin = Double.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_THRESHOLD_MIN)); - double threshold = - getThreshold( - embeddingThreshold, - embeddingThresholdMin, - chatQueryContext.getMapModeEnum()); + double threshold = getThreshold(embeddingThreshold, embeddingThresholdMin, + chatQueryContext.getMapModeEnum()); // step1. build query params RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(queryTextsSub).build(); @@ -85,75 +77,45 @@ public class EmbeddingMatchStrategy extends BatchMatchStrategy // step2. retrieveQuery by detectSegment int embeddingNumber = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_NUMBER)); - List retrieveQueryResults = - metaEmbeddingService.retrieveQuery( - retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds); + List retrieveQueryResults = metaEmbeddingService.retrieveQuery( + retrieveQuery, embeddingNumber, modelIdToDataSetIds, detectDataSetIds); if (CollectionUtils.isEmpty(retrieveQueryResults)) { return new ArrayList<>(); } // step3. build EmbeddingResults - List collect = - retrieveQueryResults.stream() - .map( - retrieveQueryResult -> { - List retrievals = retrieveQueryResult.getRetrieval(); - if (CollectionUtils.isNotEmpty(retrievals)) { - retrievals.removeIf( - retrieval -> { - if (!retrieveQueryResult - .getQuery() - .contains(retrieval.getQuery())) { - return retrieval.getSimilarity() - < threshold; - } - return false; - }); - } - return retrieveQueryResult; - }) - .filter( - retrieveQueryResult -> - CollectionUtils.isNotEmpty( - retrieveQueryResult.getRetrieval())) - .flatMap( - retrieveQueryResult -> - retrieveQueryResult.getRetrieval().stream() - .map( - retrieval -> { - EmbeddingResult embeddingResult = - new EmbeddingResult(); - BeanUtils.copyProperties( - retrieval, embeddingResult); - embeddingResult.setDetectWord( - retrieveQueryResult.getQuery()); - embeddingResult.setName( - retrieval.getQuery()); - Map convertedMap = - retrieval.getMetadata() - .entrySet().stream() - .collect( - Collectors - .toMap( - Map - .Entry - ::getKey, - entry -> - entry.getValue() - .toString())); - embeddingResult.setMetadata( - convertedMap); - return embeddingResult; - })) - .collect(Collectors.toList()); + List collect = retrieveQueryResults.stream().map(retrieveQueryResult -> { + List retrievals = retrieveQueryResult.getRetrieval(); + if (CollectionUtils.isNotEmpty(retrievals)) { + retrievals.removeIf(retrieval -> { + if (!retrieveQueryResult.getQuery().contains(retrieval.getQuery())) { + return retrieval.getSimilarity() < threshold; + } + return false; + }); + } + return retrieveQueryResult; + }).filter(retrieveQueryResult -> CollectionUtils + .isNotEmpty(retrieveQueryResult.getRetrieval())) + .flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream() + .map(retrieval -> { + EmbeddingResult embeddingResult = new EmbeddingResult(); + BeanUtils.copyProperties(retrieval, embeddingResult); + embeddingResult.setDetectWord(retrieveQueryResult.getQuery()); + embeddingResult.setName(retrieval.getQuery()); + Map convertedMap = retrieval.getMetadata().entrySet() + .stream().collect(Collectors.toMap(Map.Entry::getKey, + entry -> entry.getValue().toString())); + embeddingResult.setMetadata(convertedMap); + return embeddingResult; + })) + .collect(Collectors.toList()); // step4. select mapResul in one round int embeddingRoundNumber = Integer.valueOf(mapperConfig.getParameterValue(EMBEDDING_MAPPER_ROUND_NUMBER)); int roundNumber = embeddingRoundNumber * queryTextsSub.size(); - return collect.stream() - .sorted(Comparator.comparingDouble(EmbeddingResult::getSimilarity)) - .limit(roundNumber) - .collect(Collectors.toList()); + return collect.stream().sorted(Comparator.comparingDouble(EmbeddingResult::getSimilarity)) + .limit(roundNumber).collect(Collectors.toList()); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EntityMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EntityMapper.java index 6b1a7f7aa..138a45ea4 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EntityMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/EntityMapper.java @@ -31,19 +31,16 @@ public class EntityMapper extends BaseMapper { if (entity == null || entity.getId() == null) { continue; } - List valueSchemaElements = - schemaElementMatchList.stream() - .filter( - schemaElementMatch -> - SchemaElementType.VALUE.equals( - schemaElementMatch.getElement().getType())) - .collect(Collectors.toList()); + List valueSchemaElements = schemaElementMatchList.stream() + .filter(schemaElementMatch -> SchemaElementType.VALUE + .equals(schemaElementMatch.getElement().getType())) + .collect(Collectors.toList()); for (SchemaElementMatch schemaElementMatch : valueSchemaElements) { if (!entity.getId().equals(schemaElementMatch.getElement().getId())) { continue; } - if (!checkExistSameEntitySchemaElements( - schemaElementMatch, schemaElementMatchList)) { + if (!checkExistSameEntitySchemaElements(schemaElementMatch, + schemaElementMatchList)) { SchemaElementMatch entitySchemaElementMath = new SchemaElementMatch(); BeanUtils.copyProperties(schemaElementMatch, entitySchemaElementMath); entitySchemaElementMath.setElement(entity); @@ -54,20 +51,14 @@ public class EntityMapper extends BaseMapper { } } - private boolean checkExistSameEntitySchemaElements( - SchemaElementMatch valueSchemaElementMatch, + private boolean checkExistSameEntitySchemaElements(SchemaElementMatch valueSchemaElementMatch, List schemaElementMatchList) { - List entitySchemaElements = - schemaElementMatchList.stream() - .filter( - schemaElementMatch -> - SchemaElementType.ENTITY.equals( - schemaElementMatch.getElement().getType())) - .collect(Collectors.toList()); + List entitySchemaElements = schemaElementMatchList.stream() + .filter(schemaElementMatch -> SchemaElementType.ENTITY + .equals(schemaElementMatch.getElement().getType())) + .collect(Collectors.toList()); for (SchemaElementMatch schemaElementMatch : entitySchemaElements) { - if (schemaElementMatch - .getElement() - .getId() + if (schemaElementMatch.getElement().getId() .equals(valueSchemaElementMatch.getElement().getId())) { return true; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/HanlpDictMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/HanlpDictMatchStrategy.java index 94198faef..9a680f269 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/HanlpDictMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/HanlpDictMatchStrategy.java @@ -26,35 +26,23 @@ import static com.tencent.supersonic.headless.chat.mapper.MapperConfig.MAPPER_DI @Slf4j public class HanlpDictMatchStrategy extends SingleMatchStrategy { - @Autowired private KnowledgeBaseService knowledgeBaseService; + @Autowired + private KnowledgeBaseService knowledgeBaseService; - public List detectByStep( - ChatQueryContext chatQueryContext, - Set detectDataSetIds, - String detectSegment, - int offset) { + public List detectByStep(ChatQueryContext chatQueryContext, + Set detectDataSetIds, String detectSegment, int offset) { // step1. pre search Integer oneDetectionMaxSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DETECTION_MAX_SIZE)); - LinkedHashSet hanlpMapResults = - knowledgeBaseService - .prefixSearch( - detectSegment, - oneDetectionMaxSize, - chatQueryContext.getModelIdToDataSetIds(), - detectDataSetIds) - .stream() - .collect(Collectors.toCollection(LinkedHashSet::new)); + LinkedHashSet hanlpMapResults = knowledgeBaseService + .prefixSearch(detectSegment, oneDetectionMaxSize, + chatQueryContext.getModelIdToDataSetIds(), detectDataSetIds) + .stream().collect(Collectors.toCollection(LinkedHashSet::new)); // step2. suffix search - LinkedHashSet suffixHanlpMapResults = - knowledgeBaseService - .suffixSearch( - detectSegment, - oneDetectionMaxSize, - chatQueryContext.getModelIdToDataSetIds(), - detectDataSetIds) - .stream() - .collect(Collectors.toCollection(LinkedHashSet::new)); + LinkedHashSet suffixHanlpMapResults = knowledgeBaseService + .suffixSearch(detectSegment, oneDetectionMaxSize, + chatQueryContext.getModelIdToDataSetIds(), detectDataSetIds) + .stream().collect(Collectors.toCollection(LinkedHashSet::new)); hanlpMapResults.addAll(suffixHanlpMapResults); @@ -62,40 +50,28 @@ public class HanlpDictMatchStrategy extends SingleMatchStrategy return new ArrayList<>(); } // step3. merge pre/suffix result - hanlpMapResults = - hanlpMapResults.stream() - .sorted((a, b) -> -(b.getName().length() - a.getName().length())) - .collect(Collectors.toCollection(LinkedHashSet::new)); + hanlpMapResults = hanlpMapResults.stream() + .sorted((a, b) -> -(b.getName().length() - a.getName().length())) + .collect(Collectors.toCollection(LinkedHashSet::new)); // step4. filter by similarity - hanlpMapResults = - hanlpMapResults.stream() - .filter( - term -> - term.getSimilarity() - >= getThresholdMatch( - term.getNatures(), chatQueryContext)) - .filter(term -> CollectionUtils.isNotEmpty(term.getNatures())) - .map( - parseResult -> { - parseResult.setOffset(offset); - return parseResult; - }) - .collect(Collectors.toCollection(LinkedHashSet::new)); + hanlpMapResults = hanlpMapResults.stream() + .filter(term -> term.getSimilarity() >= getThresholdMatch(term.getNatures(), + chatQueryContext)) + .filter(term -> CollectionUtils.isNotEmpty(term.getNatures())).map(parseResult -> { + parseResult.setOffset(offset); + return parseResult; + }).collect(Collectors.toCollection(LinkedHashSet::new)); - log.debug( - "detectSegment:{},after isSimilarity parseResults:{}", - detectSegment, + log.debug("detectSegment:{},after isSimilarity parseResults:{}", detectSegment, hanlpMapResults); // step5. take only M dimensionValue or N-M metric/dimension value per rond. int oneDetectionValueSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DIMENSION_VALUE_SIZE)); - List dimensionValues = - hanlpMapResults.stream() - .filter(entry -> mapperHelper.existDimensionValues(entry.getNatures())) - .limit(oneDetectionValueSize) - .collect(Collectors.toList()); + List dimensionValues = hanlpMapResults.stream() + .filter(entry -> mapperHelper.existDimensionValues(entry.getNatures())) + .limit(oneDetectionValueSize).collect(Collectors.toList()); Integer oneDetectionSize = Integer.valueOf(mapperConfig.getParameterValue(MAPPER_DETECTION_SIZE)); @@ -108,14 +84,10 @@ public class HanlpDictMatchStrategy extends SingleMatchStrategy // fill the rest of the list with other results, excluding the dimensionValue if it was // added if (oneRoundResults.size() < oneDetectionSize) { - List additionalResults = - hanlpMapResults.stream() - .filter( - entry -> - !mapperHelper.existDimensionValues(entry.getNatures()) - && !oneRoundResults.contains(entry)) - .limit(oneDetectionSize - oneRoundResults.size()) - .collect(Collectors.toList()); + List additionalResults = hanlpMapResults.stream() + .filter(entry -> !mapperHelper.existDimensionValues(entry.getNatures()) + && !oneRoundResults.contains(entry)) + .limit(oneDetectionSize - oneRoundResults.size()).collect(Collectors.toList()); oneRoundResults.addAll(additionalResults); } return oneRoundResults; @@ -124,17 +96,13 @@ public class HanlpDictMatchStrategy extends SingleMatchStrategy public double getThresholdMatch(List natures, ChatQueryContext chatQueryContext) { Double threshold = Double.valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD)); - Double minThreshold = - Double.valueOf( - mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD_MIN)); + Double minThreshold = Double + .valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_NAME_THRESHOLD_MIN)); if (mapperHelper.existDimensionValues(natures)) { - threshold = - Double.valueOf( - mapperConfig.getParameterValue(MapperConfig.MAPPER_VALUE_THRESHOLD)); - minThreshold = - Double.valueOf( - mapperConfig.getParameterValue( - MapperConfig.MAPPER_VALUE_THRESHOLD_MIN)); + threshold = Double + .valueOf(mapperConfig.getParameterValue(MapperConfig.MAPPER_VALUE_THRESHOLD)); + minThreshold = Double.valueOf( + mapperConfig.getParameterValue(MapperConfig.MAPPER_VALUE_THRESHOLD_MIN)); } return getThreshold(threshold, minThreshold, chatQueryContext.getMapModeEnum()); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java index 84d64a6f1..85df0b970 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/KeywordMapper.java @@ -51,21 +51,15 @@ public class KeywordMapper extends BaseMapper { convertDatabaseMapResultToMapInfo(chatQueryContext, databaseResults); } - private void convertHanlpMapResultToMapInfo( - List mapResults, - ChatQueryContext chatQueryContext, - List terms) { + private void convertHanlpMapResultToMapInfo(List mapResults, + ChatQueryContext chatQueryContext, List terms) { if (CollectionUtils.isEmpty(mapResults)) { return; } HanlpHelper.transLetterOriginal(mapResults); - Map wordNatureToFrequency = - terms.stream() - .collect( - Collectors.toMap( - entry -> entry.getWord() + entry.getNature(), - term -> Long.valueOf(term.getFrequency()), - (value1, value2) -> value2)); + Map wordNatureToFrequency = terms.stream() + .collect(Collectors.toMap(entry -> entry.getWord() + entry.getNature(), + term -> Long.valueOf(term.getFrequency()), (value1, value2) -> value2)); for (HanlpMapResult hanlpMapResult : mapResults) { for (String nature : hanlpMapResult.getNatures()) { @@ -78,32 +72,24 @@ public class KeywordMapper extends BaseMapper { continue; } Long elementID = NatureHelper.getElementID(nature); - SchemaElement element = - getSchemaElement( - dataSetId, - elementType, - elementID, - chatQueryContext.getSemanticSchema()); + SchemaElement element = getSchemaElement(dataSetId, elementType, elementID, + chatQueryContext.getSemanticSchema()); if (element == null) { continue; } Long frequency = wordNatureToFrequency.get(hanlpMapResult.getName() + nature); - SchemaElementMatch schemaElementMatch = - SchemaElementMatch.builder() - .element(element) - .frequency(frequency) - .word(hanlpMapResult.getName()) - .similarity(hanlpMapResult.getSimilarity()) - .detectWord(hanlpMapResult.getDetectWord()) - .build(); + SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder() + .element(element).frequency(frequency).word(hanlpMapResult.getName()) + .similarity(hanlpMapResult.getSimilarity()) + .detectWord(hanlpMapResult.getDetectWord()).build(); addToSchemaMap(chatQueryContext.getMapInfo(), dataSetId, schemaElementMatch); } } } - private void convertDatabaseMapResultToMapInfo( - ChatQueryContext chatQueryContext, List mapResults) { + private void convertDatabaseMapResultToMapInfo(ChatQueryContext chatQueryContext, + List mapResults) { for (DatabaseMapResult match : mapResults) { SchemaElement schemaElement = match.getSchemaElement(); Set regElementSet = @@ -111,20 +97,14 @@ public class KeywordMapper extends BaseMapper { if (regElementSet.contains(schemaElement.getId())) { continue; } - SchemaElementMatch schemaElementMatch = - SchemaElementMatch.builder() - .element(schemaElement) - .word(schemaElement.getName()) - .detectWord(match.getDetectWord()) - .frequency(BaseWordBuilder.DEFAULT_FREQUENCY) - .similarity( - EditDistanceUtils.getSimilarity( - match.getDetectWord(), schemaElement.getName())) - .build(); + SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder() + .element(schemaElement).word(schemaElement.getName()) + .detectWord(match.getDetectWord()).frequency(BaseWordBuilder.DEFAULT_FREQUENCY) + .similarity(EditDistanceUtils.getSimilarity(match.getDetectWord(), + schemaElement.getName())) + .build(); log.info("add to schema, elementMatch {}", schemaElementMatch); - addToSchemaMap( - chatQueryContext.getMapInfo(), - schemaElement.getDataSetId(), + addToSchemaMap(chatQueryContext.getMapInfo(), schemaElement.getDataSetId(), schemaElementMatch); } } @@ -135,13 +115,9 @@ public class KeywordMapper extends BaseMapper { if (CollectionUtils.isEmpty(elements)) { return new HashSet<>(); } - return elements.stream() - .filter( - elementMatch -> - SchemaElementType.METRIC.equals(elementMatch.getElement().getType()) - || SchemaElementType.DIMENSION.equals( - elementMatch.getElement().getType())) - .map(elementMatch -> elementMatch.getElement().getId()) - .collect(Collectors.toSet()); + return elements.stream().filter( + elementMatch -> SchemaElementType.METRIC.equals(elementMatch.getElement().getType()) + || SchemaElementType.DIMENSION.equals(elementMatch.getElement().getType())) + .map(elementMatch -> elementMatch.getElement().getId()).collect(Collectors.toSet()); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapFilter.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapFilter.java index bec13b67a..948adfd15 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapFilter.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapFilter.java @@ -26,19 +26,16 @@ public class MapFilter { filterByQueryDataType(chatQueryContext, element -> !(element.getIsTag() > 0)); break; case METRIC: - filterByQueryDataType( - chatQueryContext, + filterByQueryDataType(chatQueryContext, element -> !SchemaElementType.METRIC.equals(element.getType())); break; case DIMENSION: - filterByQueryDataType( - chatQueryContext, - element -> { - boolean isDimensionOrValue = - SchemaElementType.DIMENSION.equals(element.getType()) - || SchemaElementType.VALUE.equals(element.getType()); - return !isDimensionOrValue; - }); + filterByQueryDataType(chatQueryContext, element -> { + boolean isDimensionOrValue = + SchemaElementType.DIMENSION.equals(element.getType()) + || SchemaElementType.VALUE.equals(element.getType()); + return !isDimensionOrValue; + }); break; case ALL: default: @@ -67,31 +64,28 @@ public class MapFilter { for (Map.Entry> entry : dataSetElementMatches.entrySet()) { List value = entry.getValue(); if (!CollectionUtils.isEmpty(value)) { - value.removeIf( - schemaElementMatch -> - StringUtils.length(schemaElementMatch.getDetectWord()) <= 1); + value.removeIf(schemaElementMatch -> StringUtils + .length(schemaElementMatch.getDetectWord()) <= 1); } } } - public static void filterByQueryDataType( - ChatQueryContext chatQueryContext, Predicate needRemovePredicate) { + public static void filterByQueryDataType(ChatQueryContext chatQueryContext, + Predicate needRemovePredicate) { Map> dataSetElementMatches = chatQueryContext.getMapInfo().getDataSetElementMatches(); for (Map.Entry> entry : dataSetElementMatches.entrySet()) { List schemaElementMatches = entry.getValue(); - schemaElementMatches.removeIf( - schemaElementMatch -> { - SchemaElement element = schemaElementMatch.getElement(); - SchemaElementType type = element.getType(); + schemaElementMatches.removeIf(schemaElementMatch -> { + SchemaElement element = schemaElementMatch.getElement(); + SchemaElementType type = element.getType(); - boolean isEntityOrDatasetOrId = - SchemaElementType.ENTITY.equals(type) - || SchemaElementType.DATASET.equals(type) - || SchemaElementType.ID.equals(type); + boolean isEntityOrDatasetOrId = SchemaElementType.ENTITY.equals(type) + || SchemaElementType.DATASET.equals(type) + || SchemaElementType.ID.equals(type); - return !isEntityOrDatasetOrId && needRemovePredicate.test(element); - }); + return !isEntityOrDatasetOrId && needRemovePredicate.test(element); + }); } } @@ -116,21 +110,16 @@ public class MapFilter { List group = entry.getValue(); // Filter out objects with similarity=1.0 - List fullMatches = - group.stream() - .filter(SchemaElementMatch::isFullMatched) - .collect(Collectors.toList()); + List fullMatches = group.stream() + .filter(SchemaElementMatch::isFullMatched).collect(Collectors.toList()); if (!fullMatches.isEmpty()) { // If there are objects with similarity=1.0, choose the one with the longest // detectWord and smallest offset - SchemaElementMatch bestMatch = - fullMatches.stream() - .max( - Comparator.comparing( - (SchemaElementMatch match) -> - match.getDetectWord().length())) - .orElse(null); + SchemaElementMatch bestMatch = fullMatches.stream() + .max(Comparator.comparing( + (SchemaElementMatch match) -> match.getDetectWord().length())) + .orElse(null); if (bestMatch != null) { result.add(bestMatch); } @@ -145,8 +134,7 @@ public class MapFilter { public static void filterInExactMatch(List matches) { Map> fullMatches = - matches.stream() - .filter(schemaElementMatch -> schemaElementMatch.isFullMatched()) + matches.stream().filter(schemaElementMatch -> schemaElementMatch.isFullMatched()) .collect(Collectors.groupingBy(SchemaElementMatch::getWord)); Set keys = new HashSet<>(fullMatches.keySet()); for (String key1 : keys) { @@ -157,8 +145,7 @@ public class MapFilter { } } List notFullMatches = - matches.stream() - .filter(schemaElementMatch -> !schemaElementMatch.isFullMatched()) + matches.stream().filter(schemaElementMatch -> !schemaElementMatch.isFullMatched()) .collect(Collectors.toList()); List mergedMatches = new ArrayList<>(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperConfig.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperConfig.java index a3fb4ba60..bbcaecdb8 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperConfig.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperConfig.java @@ -7,129 +7,58 @@ import org.springframework.stereotype.Service; @Service("HeadlessMapperConfig") public class MapperConfig extends ParameterConfig { - public static final Parameter MAPPER_DETECTION_SIZE = - new Parameter( - "s2.mapper.detection.size", - "8", - "一次探测返回结果个数", - "在每次探测后, 将前后缀匹配的结果合并, 并根据相似度阈值过滤后的结果个数", - "number", - "Mapper相关配置"); + public static final Parameter MAPPER_DETECTION_SIZE = new Parameter("s2.mapper.detection.size", + "8", "一次探测返回结果个数", "在每次探测后, 将前后缀匹配的结果合并, 并根据相似度阈值过滤后的结果个数", "number", "Mapper相关配置"); public static final Parameter MAPPER_DETECTION_MAX_SIZE = - new Parameter( - "s2.mapper.detection.max.size", - "20", - "一次探测前后缀匹配结果返回个数", - "单次前后缀匹配返回的结果个数", - "number", - "Mapper相关配置"); + new Parameter("s2.mapper.detection.max.size", "20", "一次探测前后缀匹配结果返回个数", "单次前后缀匹配返回的结果个数", + "number", "Mapper相关配置"); public static final Parameter MAPPER_NAME_THRESHOLD = - new Parameter( - "s2.mapper.name.threshold", - "0.5", - "指标名、维度名文本相似度阈值", - "文本片段和匹配到的指标、维度名计算出来的编辑距离阈值, 若超出该阈值, 则舍弃", - "number", - "Mapper相关配置"); + new Parameter("s2.mapper.name.threshold", "0.5", "指标名、维度名文本相似度阈值", + "文本片段和匹配到的指标、维度名计算出来的编辑距离阈值, 若超出该阈值, 则舍弃", "number", "Mapper相关配置"); public static final Parameter MAPPER_NAME_THRESHOLD_MIN = - new Parameter( - "s2.mapper.name.min.threshold", - "0.25", - "指标名、维度名最小文本相似度阈值", - "指标名、维度名相似度阈值在动态调整中的最低值", - "number", - "Mapper相关配置"); + new Parameter("s2.mapper.name.min.threshold", "0.25", "指标名、维度名最小文本相似度阈值", + "指标名、维度名相似度阈值在动态调整中的最低值", "number", "Mapper相关配置"); public static final Parameter MAPPER_DIMENSION_VALUE_SIZE = - new Parameter( - "s2.mapper.value.size", - "1", - "一次探测返回维度值结果个数", - "在每次探测后, 将前后缀匹配的结果合并, 并根据相似度阈值过滤后的维度值结果个数", - "number", - "Mapper相关配置"); + new Parameter("s2.mapper.value.size", "1", "一次探测返回维度值结果个数", + "在每次探测后, 将前后缀匹配的结果合并, 并根据相似度阈值过滤后的维度值结果个数", "number", "Mapper相关配置"); public static final Parameter MAPPER_VALUE_THRESHOLD = - new Parameter( - "s2.mapper.value.threshold", - "0.5", - "维度值文本相似度阈值", - "文本片段和匹配到的维度值计算出来的编辑距离阈值, 若超出该阈值, 则舍弃", - "number", - "Mapper相关配置"); + new Parameter("s2.mapper.value.threshold", "0.5", "维度值文本相似度阈值", + "文本片段和匹配到的维度值计算出来的编辑距离阈值, 若超出该阈值, 则舍弃", "number", "Mapper相关配置"); public static final Parameter MAPPER_VALUE_THRESHOLD_MIN = - new Parameter( - "s2.mapper.value.min.threshold", - "0.3", - "维度值最小文本相似度阈值", - "维度值相似度阈值在动态调整中的最低值", - "number", - "Mapper相关配置"); + new Parameter("s2.mapper.value.min.threshold", "0.3", "维度值最小文本相似度阈值", + "维度值相似度阈值在动态调整中的最低值", "number", "Mapper相关配置"); public static final Parameter EMBEDDING_MAPPER_TEXT_SIZE = - new Parameter( - "s2.mapper.embedding.word.size", - "4", - "用于向量召回文本长度", - "为提高向量召回效率, 按指定长度进行向量语义召回", - "number", - "Mapper相关配置"); + new Parameter("s2.mapper.embedding.word.size", "4", "用于向量召回文本长度", + "为提高向量召回效率, 按指定长度进行向量语义召回", "number", "Mapper相关配置"); public static final Parameter EMBEDDING_MAPPER_TEXT_STEP = - new Parameter( - "s2.mapper.embedding.word.step", - "3", - "向量召回文本每步长度", - "为提高向量召回效率, 按指定每步长度进行召回", - "number", - "Mapper相关配置"); + new Parameter("s2.mapper.embedding.word.step", "3", "向量召回文本每步长度", + "为提高向量召回效率, 按指定每步长度进行召回", "number", "Mapper相关配置"); public static final Parameter EMBEDDING_MAPPER_BATCH = - new Parameter( - "s2.mapper.embedding.batch", - "50", - "批量向量召回文本请求个数", - "每次进行向量语义召回的原始文本片段个数", - "number", - "Mapper相关配置"); + new Parameter("s2.mapper.embedding.batch", "50", "批量向量召回文本请求个数", "每次进行向量语义召回的原始文本片段个数", + "number", "Mapper相关配置"); public static final Parameter EMBEDDING_MAPPER_NUMBER = - new Parameter( - "s2.mapper.embedding.number", - "5", - "批量向量召回文本返回结果个数", - "每个文本进行向量语义召回的文本结果个数", - "number", - "Mapper相关配置"); + new Parameter("s2.mapper.embedding.number", "5", "批量向量召回文本返回结果个数", + "每个文本进行向量语义召回的文本结果个数", "number", "Mapper相关配置"); public static final Parameter EMBEDDING_MAPPER_THRESHOLD = - new Parameter( - "s2.mapper.embedding.threshold", - "0.98", - "向量召回相似度阈值", - "相似度小于该阈值的则舍弃", - "number", - "Mapper相关配置"); + new Parameter("s2.mapper.embedding.threshold", "0.98", "向量召回相似度阈值", "相似度小于该阈值的则舍弃", + "number", "Mapper相关配置"); public static final Parameter EMBEDDING_MAPPER_THRESHOLD_MIN = - new Parameter( - "s2.mapper.embedding.min.threshold", - "0.9", - "向量召回最小相似度阈值", - "向量召回相似度阈值在动态调整中的最低值", - "number", - "Mapper相关配置"); + new Parameter("s2.mapper.embedding.min.threshold", "0.9", "向量召回最小相似度阈值", + "向量召回相似度阈值在动态调整中的最低值", "number", "Mapper相关配置"); public static final Parameter EMBEDDING_MAPPER_ROUND_NUMBER = - new Parameter( - "s2.mapper.embedding.round.number", - "10", - "向量召回最小相似度阈值", - "向量召回相似度阈值在动态调整中的最低值", - "number", - "Mapper相关配置"); + new Parameter("s2.mapper.embedding.round.number", "10", "向量召回最小相似度阈值", + "向量召回相似度阈值在动态调整中的最低值", "number", "Mapper相关配置"); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperHelper.java index 5566bb00a..8a00ab58b 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MapperHelper.java @@ -28,11 +28,8 @@ public class MapperHelper { } public Integer getStepOffset(List termList, Integer index) { - List offsetList = - termList.stream() - .sorted(Comparator.comparing(S2Term::getOffset)) - .map(term -> term.getOffset()) - .collect(Collectors.toList()); + List offsetList = termList.stream().sorted(Comparator.comparing(S2Term::getOffset)) + .map(term -> term.getOffset()).collect(Collectors.toList()); for (int j = 0; j < termList.size() - 1; j++) { if (offsetList.get(j) <= index && offsetList.get(j + 1) > index) { @@ -43,13 +40,8 @@ public class MapperHelper { } public Map getRegOffsetToLength(List terms) { - return terms.stream() - .sorted(Comparator.comparing(S2Term::length)) - .collect( - Collectors.toMap( - S2Term::getOffset, - term -> term.word.length(), - (value1, value2) -> value2)); + return terms.stream().sorted(Comparator.comparing(S2Term::length)).collect(Collectors + .toMap(S2Term::getOffset, term -> term.word.length(), (value1, value2) -> value2)); } /** diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MatchStrategy.java index 049b138f8..d97b76d80 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/MatchStrategy.java @@ -13,6 +13,6 @@ import java.util.Set; */ public interface MatchStrategy { - Map> match( - ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds); + Map> match(ChatQueryContext chatQueryContext, List terms, + Set detectDataSetIds); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/QueryFilterMapper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/QueryFilterMapper.java index da75b37d9..e9c168207 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/QueryFilterMapper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/QueryFilterMapper.java @@ -43,17 +43,15 @@ public class QueryFilterMapper extends BaseMapper { } private void clearOtherSchemaElementMatch(Set viewIds, SchemaMapInfo schemaMapInfo) { - for (Map.Entry> entry : - schemaMapInfo.getDataSetElementMatches().entrySet()) { + for (Map.Entry> entry : schemaMapInfo + .getDataSetElementMatches().entrySet()) { if (!viewIds.contains(entry.getKey())) { entry.getValue().clear(); } } } - private void addValueSchemaElementMatch( - Long dataSetId, - ChatQueryContext chatQueryContext, + private void addValueSchemaElementMatch(Long dataSetId, ChatQueryContext chatQueryContext, List candidateElementMatches) { QueryFilters queryFilters = chatQueryContext.getQueryFilters(); if (queryFilters == null || CollectionUtils.isEmpty(queryFilters.getFilters())) { @@ -63,40 +61,27 @@ public class QueryFilterMapper extends BaseMapper { if (checkExistSameValueSchemaElementMatch(filter, candidateElementMatches)) { continue; } - SchemaElement element = - SchemaElement.builder() - .id(filter.getElementID()) - .name(String.valueOf(filter.getValue())) - .type(SchemaElementType.VALUE) - .bizName(filter.getBizName()) - .dataSetId(dataSetId) - .build(); - SchemaElementMatch schemaElementMatch = - SchemaElementMatch.builder() - .element(element) - .frequency(BaseWordBuilder.DEFAULT_FREQUENCY) - .word(String.valueOf(filter.getValue())) - .similarity(similarity) - .detectWord(Constants.EMPTY) - .build(); + SchemaElement element = SchemaElement.builder().id(filter.getElementID()) + .name(String.valueOf(filter.getValue())).type(SchemaElementType.VALUE) + .bizName(filter.getBizName()).dataSetId(dataSetId).build(); + SchemaElementMatch schemaElementMatch = SchemaElementMatch.builder().element(element) + .frequency(BaseWordBuilder.DEFAULT_FREQUENCY) + .word(String.valueOf(filter.getValue())).similarity(similarity) + .detectWord(Constants.EMPTY).build(); candidateElementMatches.add(schemaElementMatch); } chatQueryContext.getMapInfo().setMatchedElements(dataSetId, candidateElementMatches); } - private boolean checkExistSameValueSchemaElementMatch( - QueryFilter queryFilter, List schemaElementMatches) { - List valueSchemaElements = - schemaElementMatches.stream() - .filter( - schemaElementMatch -> - SchemaElementType.VALUE.equals( - schemaElementMatch.getElement().getType())) - .collect(Collectors.toList()); + private boolean checkExistSameValueSchemaElementMatch(QueryFilter queryFilter, + List schemaElementMatches) { + List valueSchemaElements = schemaElementMatches.stream() + .filter(schemaElementMatch -> SchemaElementType.VALUE + .equals(schemaElementMatch.getElement().getType())) + .collect(Collectors.toList()); for (SchemaElementMatch schemaElementMatch : valueSchemaElements) { if (schemaElementMatch.getElement().getId().equals(queryFilter.getElementID()) - && schemaElementMatch - .getWord() + && schemaElementMatch.getWord() .equals(String.valueOf(queryFilter.getValue()))) { return true; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SearchMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SearchMatchStrategy.java index 4afd02b9c..913c0e619 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SearchMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SearchMatchStrategy.java @@ -27,19 +27,21 @@ public class SearchMatchStrategy extends BaseMatchStrategy { private static final int SEARCH_SIZE = 3; - @Autowired private KnowledgeBaseService knowledgeBaseService; + @Autowired + private KnowledgeBaseService knowledgeBaseService; - @Autowired private MapperHelper mapperHelper; + @Autowired + private MapperHelper mapperHelper; @Override - public Map> match( - ChatQueryContext chatQueryContext, List originals, Set detectDataSetIds) { + public Map> match(ChatQueryContext chatQueryContext, + List originals, Set detectDataSetIds) { String text = chatQueryContext.getQueryText(); Map regOffsetToLength = mapperHelper.getRegOffsetToLength(originals); List detectIndexList = Lists.newArrayList(); - for (Integer index = 0; index < text.length(); ) { + for (Integer index = 0; index < text.length();) { if (index < text.length()) { detectIndexList.add(index); @@ -52,58 +54,33 @@ public class SearchMatchStrategy extends BaseMatchStrategy { } } Map> regTextMap = new ConcurrentHashMap<>(); - detectIndexList.stream() - .parallel() - .forEach( - detectIndex -> { - String regText = text.substring(0, detectIndex); - String detectSegment = text.substring(detectIndex); + detectIndexList.stream().parallel().forEach(detectIndex -> { + String regText = text.substring(0, detectIndex); + String detectSegment = text.substring(detectIndex); - if (StringUtils.isNotEmpty(detectSegment)) { - List hanlpMapResults = - knowledgeBaseService.prefixSearch( - detectSegment, - SearchService.SEARCH_SIZE, - chatQueryContext.getModelIdToDataSetIds(), - detectDataSetIds); - List suffixHanlpMapResults = - knowledgeBaseService.suffixSearch( - detectSegment, - SEARCH_SIZE, - chatQueryContext.getModelIdToDataSetIds(), - detectDataSetIds); - hanlpMapResults.addAll(suffixHanlpMapResults); - // remove entity name where search - hanlpMapResults = - hanlpMapResults.stream() - .filter( - entry -> { - List natures = - entry.getNatures().stream() - .filter( - nature -> - !nature - .endsWith( - DictWordType - .ENTITY - .getType())) - .collect( - Collectors - .toList()); - if (CollectionUtils.isEmpty(natures)) { - return false; - } - return true; - }) - .collect(Collectors.toList()); - MatchText matchText = - MatchText.builder() - .regText(regText) - .detectSegment(detectSegment) - .build(); - regTextMap.put(matchText, hanlpMapResults); - } - }); + if (StringUtils.isNotEmpty(detectSegment)) { + List hanlpMapResults = + knowledgeBaseService.prefixSearch(detectSegment, SearchService.SEARCH_SIZE, + chatQueryContext.getModelIdToDataSetIds(), detectDataSetIds); + List suffixHanlpMapResults = + knowledgeBaseService.suffixSearch(detectSegment, SEARCH_SIZE, + chatQueryContext.getModelIdToDataSetIds(), detectDataSetIds); + hanlpMapResults.addAll(suffixHanlpMapResults); + // remove entity name where search + hanlpMapResults = hanlpMapResults.stream().filter(entry -> { + List natures = entry.getNatures().stream() + .filter(nature -> !nature.endsWith(DictWordType.ENTITY.getType())) + .collect(Collectors.toList()); + if (CollectionUtils.isEmpty(natures)) { + return false; + } + return true; + }).collect(Collectors.toList()); + MatchText matchText = + MatchText.builder().regText(regText).detectSegment(detectSegment).build(); + regTextMap.put(matchText, hanlpMapResults); + } + }); return regTextMap; } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java index 6b1b69b5e..9b3077ce9 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/mapper/SingleMatchStrategy.java @@ -16,20 +16,22 @@ import java.util.Set; @Service @Slf4j public abstract class SingleMatchStrategy extends BaseMatchStrategy { - @Autowired protected MapperConfig mapperConfig; - @Autowired protected MapperHelper mapperHelper; + @Autowired + protected MapperConfig mapperConfig; + @Autowired + protected MapperHelper mapperHelper; - public List detect( - ChatQueryContext chatQueryContext, List terms, Set detectDataSetIds) { + public List detect(ChatQueryContext chatQueryContext, List terms, + Set detectDataSetIds) { Map regOffsetToLength = mapperHelper.getRegOffsetToLength(terms); String text = chatQueryContext.getQueryText(); Set results = new HashSet<>(); Set detectSegments = new HashSet<>(); - for (Integer startIndex = 0; startIndex <= text.length() - 1; ) { + for (Integer startIndex = 0; startIndex <= text.length() - 1;) { - for (Integer index = startIndex; index <= text.length(); ) { + for (Integer index = startIndex; index <= text.length();) { int offset = mapperHelper.getStepOffset(terms, startIndex); index = mapperHelper.getStepIndex(regOffsetToLength, index); if (index <= text.length()) { @@ -45,9 +47,6 @@ public abstract class SingleMatchStrategy extends BaseMatch return new ArrayList<>(results); } - public abstract List detectByStep( - ChatQueryContext chatQueryContext, - Set detectDataSetIds, - String detectSegment, - int offset); + public abstract List detectByStep(ChatQueryContext chatQueryContext, + Set detectDataSetIds, String detectSegment, int offset); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/ParserConfig.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/ParserConfig.java index 8c4ccfe42..90c7fd3a4 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/ParserConfig.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/ParserConfig.java @@ -13,89 +13,45 @@ import java.util.List; public class ParserConfig extends ParameterConfig { public static final Parameter PARSER_STRATEGY_TYPE = - new Parameter( - "s2.parser.s2sql.strategy", - "ONE_PASS_SELF_CONSISTENCY", - "LLM解析生成S2SQL策略", - "ONE_PASS_SELF_CONSISTENCY: 通过投票方式一步生成sql", - "list", - "Parser相关配置", + new Parameter("s2.parser.s2sql.strategy", "ONE_PASS_SELF_CONSISTENCY", "LLM解析生成S2SQL策略", + "ONE_PASS_SELF_CONSISTENCY: 通过投票方式一步生成sql", "list", "Parser相关配置", Lists.newArrayList("ONE_PASS_SELF_CONSISTENCY")); public static final Parameter PARSER_LINKING_VALUE_ENABLE = - new Parameter( - "s2.parser.linking.value.enable", - "true", - "是否将Mapper探测识别到的维度值提供给大模型", - "为了数据安全考虑, 这里可进行开关选择", - "bool", - "Parser相关配置"); + new Parameter("s2.parser.linking.value.enable", "true", "是否将Mapper探测识别到的维度值提供给大模型", + "为了数据安全考虑, 这里可进行开关选择", "bool", "Parser相关配置"); public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD = - new Parameter( - "s2.parser.text.length.threshold", - "10", - "用户输入文本长短阈值", - "文本超过该阈值为长文本", - "number", - "Parser相关配置"); + new Parameter("s2.parser.text.length.threshold", "10", "用户输入文本长短阈值", "文本超过该阈值为长文本", + "number", "Parser相关配置"); public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_SHORT = - new Parameter( - "s2.parser.text.threshold.short", - "0.5", - "短文本匹配阈值", + new Parameter("s2.parser.text.threshold.short", "0.5", "短文本匹配阈值", "由于请求大模型耗时较长, 因此如果有规则类型的Query得分达到阈值,则跳过大模型的调用," + "\n如果是短文本, 若query得分/文本长度>该阈值, 则跳过当前parser", - "number", - "Parser相关配置"); + "number", "Parser相关配置"); public static final Parameter PARSER_TEXT_LENGTH_THRESHOLD_LONG = - new Parameter( - "s2.parser.text.threshold.long", - "0.8", - "长文本匹配阈值", - "如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser", - "number", - "Parser相关配置"); + new Parameter("s2.parser.text.threshold.long", "0.8", "长文本匹配阈值", + "如果是长文本, 若query得分/文本长度>该阈值, 则跳过当前parser", "number", "Parser相关配置"); - public static final Parameter PARSER_EXEMPLAR_RECALL_NUMBER = - new Parameter( - "s2.parser.exemplar-recall.number", - "10", - "exemplar召回个数", - "", - "number", - "Parser相关配置"); + public static final Parameter PARSER_EXEMPLAR_RECALL_NUMBER = new Parameter( + "s2.parser.exemplar-recall.number", "10", "exemplar召回个数", "", "number", "Parser相关配置"); public static final Parameter PARSER_FEW_SHOT_NUMBER = - new Parameter( - "s2.parser.few-shot.number", - "3", - "few-shot样例个数", - "样例越多效果可能越好,但token消耗越大", - "number", - "Parser相关配置"); + new Parameter("s2.parser.few-shot.number", "3", "few-shot样例个数", "样例越多效果可能越好,但token消耗越大", + "number", "Parser相关配置"); public static final Parameter PARSER_SELF_CONSISTENCY_NUMBER = - new Parameter( - "s2.parser.self-consistency.number", - "1", - "self-consistency执行个数", - "执行越多效果可能越好,但token消耗越大", - "number", - "Parser相关配置"); + new Parameter("s2.parser.self-consistency.number", "1", "self-consistency执行个数", + "执行越多效果可能越好,但token消耗越大", "number", "Parser相关配置"); - public static final Parameter PARSER_SHOW_COUNT = - new Parameter( - "s2.parser.show.count", "3", "解析结果展示个数", "前端展示的解析个数", "number", "Parser相关配置"); + public static final Parameter PARSER_SHOW_COUNT = new Parameter("s2.parser.show.count", "3", + "解析结果展示个数", "前端展示的解析个数", "number", "Parser相关配置"); @Override public List getSysParameters() { - return Lists.newArrayList( - PARSER_LINKING_VALUE_ENABLE, - PARSER_FEW_SHOT_NUMBER, - PARSER_SELF_CONSISTENCY_NUMBER, - PARSER_SHOW_COUNT); + return Lists.newArrayList(PARSER_LINKING_VALUE_ENABLE, PARSER_FEW_SHOT_NUMBER, + PARSER_SELF_CONSISTENCY_NUMBER, PARSER_SHOW_COUNT); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java index 3e4c09da4..031dc97ba 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/QueryTypeParser.java @@ -59,10 +59,8 @@ public class QueryTypeParser implements SemanticParser { List whereFields = SqlSelectHelper.getWhereFields(sqlInfo.getParsedS2SQL()); List whereFilterByTimeFields = filterByTimeFields(whereFields); if (CollectionUtils.isNotEmpty(whereFilterByTimeFields)) { - Set ids = - semanticSchema.getEntities(dataSetId).stream() - .map(SchemaElement::getName) - .collect(Collectors.toSet()); + Set ids = semanticSchema.getEntities(dataSetId).stream() + .map(SchemaElement::getName).collect(Collectors.toSet()); if (CollectionUtils.isNotEmpty(ids) && ids.stream().anyMatch(whereFilterByTimeFields::contains)) { return QueryType.ID; @@ -80,15 +78,14 @@ public class QueryTypeParser implements SemanticParser { } private static List filterByTimeFields(List whereFields) { - List selectAndWhereFilterByTimeFields = - whereFields.stream() - .filter(field -> !TimeDimensionEnum.containsTimeDimension(field)) - .collect(Collectors.toList()); + List selectAndWhereFilterByTimeFields = whereFields.stream() + .filter(field -> !TimeDimensionEnum.containsTimeDimension(field)) + .collect(Collectors.toList()); return selectAndWhereFilterByTimeFields; } - private static boolean selectContainsMetric( - SqlInfo sqlInfo, Long dataSetId, SemanticSchema semanticSchema) { + private static boolean selectContainsMetric(SqlInfo sqlInfo, Long dataSetId, + SemanticSchema semanticSchema) { List selectFields = SqlSelectHelper.getSelectFields(sqlInfo.getParsedS2SQL()); List metrics = semanticSchema.getMetrics(dataSetId); if (CollectionUtils.isNotEmpty(metrics)) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/SatisfactionChecker.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/SatisfactionChecker.java index 52898039c..0bfd40e22 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/SatisfactionChecker.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/SatisfactionChecker.java @@ -50,10 +50,7 @@ public class SatisfactionChecker { } else if (degree < shortTextLengthThreshold) { return false; } - log.info( - "queryMode:{}, degree:{}, parse info:{}", - semanticParseInfo.getQueryMode(), - degree, + log.info("queryMode:{}, degree:{}, parse info:{}", semanticParseInfo.getQueryMode(), degree, semanticParseInfo); return true; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/HeuristicDataSetResolver.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/HeuristicDataSetResolver.java index b745b9925..79790fea1 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/HeuristicDataSetResolver.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/HeuristicDataSetResolver.java @@ -37,26 +37,19 @@ public class HeuristicDataSetResolver implements DataSetResolver { protected Long selectDataSetByMatchSimilarity(SchemaMapInfo schemaMap) { Map dataSetMatchRet = getDataSetMatchResult(schemaMap); Entry selectedDataset = - dataSetMatchRet.entrySet().stream() - .sorted( - (o1, o2) -> { - double difference = - o1.getValue().getMaxDatesetSimilarity() - - o2.getValue().getMaxDatesetSimilarity(); - if (difference == 0) { - difference = - o1.getValue().getMaxMetricSimilarity() - - o2.getValue().getMaxMetricSimilarity(); - if (difference == 0) { - difference = - o1.getValue().getTotalSimilarity() - - o2.getValue().getTotalSimilarity(); - } - } - return difference >= 0 ? -1 : 1; - }) - .findFirst() - .orElse(null); + dataSetMatchRet.entrySet().stream().sorted((o1, o2) -> { + double difference = o1.getValue().getMaxDatesetSimilarity() + - o2.getValue().getMaxDatesetSimilarity(); + if (difference == 0) { + difference = o1.getValue().getMaxMetricSimilarity() + - o2.getValue().getMaxMetricSimilarity(); + if (difference == 0) { + difference = o1.getValue().getTotalSimilarity() + - o2.getValue().getTotalSimilarity(); + } + } + return difference >= 0 ? -1 : 1; + }).findFirst().orElse(null); if (selectedDataset != null) { log.info("selectDataSet with multiple DataSets [{}]", selectedDataset.getKey()); return selectedDataset.getKey(); @@ -67,8 +60,8 @@ public class HeuristicDataSetResolver implements DataSetResolver { protected Map getDataSetMatchResult(SchemaMapInfo schemaMap) { Map dateSetMatchRet = new HashMap<>(); - for (Entry> entry : - schemaMap.getDataSetElementMatches().entrySet()) { + for (Entry> entry : schemaMap.getDataSetElementMatches() + .entrySet()) { double maxMetricSimilarity = 0; double maxDatasetSimilarity = 0; double totalSimilarity = 0; @@ -81,13 +74,10 @@ public class HeuristicDataSetResolver implements DataSetResolver { } totalSimilarity += match.getSimilarity(); } - dateSetMatchRet.put( - entry.getKey(), - DataSetMatchResult.builder() - .maxMetricSimilarity(maxMetricSimilarity) + dateSetMatchRet.put(entry.getKey(), + DataSetMatchResult.builder().maxMetricSimilarity(maxMetricSimilarity) .maxDatesetSimilarity(maxDatasetSimilarity) - .totalSimilarity(totalSimilarity) - .build()); + .totalSimilarity(totalSimilarity).build()); } return dateSetMatchRet; diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java index 1513f128c..4ae7c2548 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMRequestService.java @@ -31,7 +31,8 @@ import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_ST @Service public class LLMRequestService { - @Autowired private ParserConfig parserConfig; + @Autowired + private ParserConfig parserConfig; public boolean isSkip(ChatQueryContext queryCtx) { if (!queryCtx.getText2SQLType().enableLLM()) { @@ -95,88 +96,63 @@ public class LLMRequestService { if (CollectionUtils.isEmpty(matchedElements)) { return new ArrayList<>(); } - return matchedElements.stream() - .filter( - schemaElementMatch -> { - SchemaElementType elementType = - schemaElementMatch.getElement().getType(); - return SchemaElementType.TERM.equals(elementType); - }) - .map( - schemaElementMatch -> { - LLMReq.Term term = new LLMReq.Term(); - term.setName(schemaElementMatch.getElement().getName()); - term.setDescription(schemaElementMatch.getElement().getDescription()); - term.setAlias(schemaElementMatch.getElement().getAlias()); - return term; - }) - .collect(Collectors.toList()); + return matchedElements.stream().filter(schemaElementMatch -> { + SchemaElementType elementType = schemaElementMatch.getElement().getType(); + return SchemaElementType.TERM.equals(elementType); + }).map(schemaElementMatch -> { + LLMReq.Term term = new LLMReq.Term(); + term.setName(schemaElementMatch.getElement().getName()); + term.setDescription(schemaElementMatch.getElement().getDescription()); + term.setAlias(schemaElementMatch.getElement().getAlias()); + return term; + }).collect(Collectors.toList()); } - protected List getMappedValues( - @NotNull ChatQueryContext queryCtx, Long dataSetId) { + protected List getMappedValues(@NotNull ChatQueryContext queryCtx, + Long dataSetId) { List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); if (CollectionUtils.isEmpty(matchedElements)) { return new ArrayList<>(); } - Set valueMatches = - matchedElements.stream() - .filter(elementMatch -> !elementMatch.isInherited()) - .filter( - schemaElementMatch -> { - SchemaElementType type = - schemaElementMatch.getElement().getType(); - return SchemaElementType.VALUE.equals(type) - || SchemaElementType.ID.equals(type); - }) - .map( - elementMatch -> { - LLMReq.ElementValue elementValue = new LLMReq.ElementValue(); - elementValue.setFieldName(elementMatch.getElement().getName()); - elementValue.setFieldValue(elementMatch.getWord()); - return elementValue; - }) - .collect(Collectors.toSet()); + Set valueMatches = matchedElements.stream() + .filter(elementMatch -> !elementMatch.isInherited()).filter(schemaElementMatch -> { + SchemaElementType type = schemaElementMatch.getElement().getType(); + return SchemaElementType.VALUE.equals(type) + || SchemaElementType.ID.equals(type); + }).map(elementMatch -> { + LLMReq.ElementValue elementValue = new LLMReq.ElementValue(); + elementValue.setFieldName(elementMatch.getElement().getName()); + elementValue.setFieldValue(elementMatch.getWord()); + return elementValue; + }).collect(Collectors.toSet()); return new ArrayList<>(valueMatches); } - protected List getMappedMetrics( - @NotNull ChatQueryContext queryCtx, Long dataSetId) { + protected List getMappedMetrics(@NotNull ChatQueryContext queryCtx, + Long dataSetId) { List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); if (CollectionUtils.isEmpty(matchedElements)) { return Collections.emptyList(); } - List schemaElements = - matchedElements.stream() - .filter( - schemaElementMatch -> { - SchemaElementType elementType = - schemaElementMatch.getElement().getType(); - return SchemaElementType.METRIC.equals(elementType); - }) - .map( - schemaElementMatch -> { - return schemaElementMatch.getElement(); - }) - .collect(Collectors.toList()); + List schemaElements = matchedElements.stream().filter(schemaElementMatch -> { + SchemaElementType elementType = schemaElementMatch.getElement().getType(); + return SchemaElementType.METRIC.equals(elementType); + }).map(schemaElementMatch -> { + return schemaElementMatch.getElement(); + }).collect(Collectors.toList()); return schemaElements; } - protected List getMappedDimensions( - @NotNull ChatQueryContext queryCtx, Long dataSetId) { + protected List getMappedDimensions(@NotNull ChatQueryContext queryCtx, + Long dataSetId) { List matchedElements = queryCtx.getMapInfo().getMatchedElements(dataSetId); - List dimensionElements = - matchedElements.stream() - .filter( - element -> - SchemaElementType.DIMENSION.equals( - element.getElement().getType())) - .map(SchemaElementMatch::getElement) - .collect(Collectors.toList()); + List dimensionElements = matchedElements.stream().filter( + element -> SchemaElementType.DIMENSION.equals(element.getElement().getType())) + .map(SchemaElementMatch::getElement).collect(Collectors.toList()); return new ArrayList<>(dimensionElements); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java index e8a302158..6aa750505 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMResponseService.java @@ -23,8 +23,8 @@ import java.util.Objects; @Service public class LLMResponseService { - public SemanticParseInfo addParseInfo( - ChatQueryContext queryCtx, ParseResult parseResult, String s2SQL, Double weight) { + public SemanticParseInfo addParseInfo(ChatQueryContext queryCtx, ParseResult parseResult, + String s2SQL, Double weight) { if (Objects.isNull(weight)) { weight = 0D; } @@ -33,20 +33,16 @@ public class LLMResponseService { parseInfo.setDataSet(queryCtx.getSemanticSchema().getDataSet(parseResult.getDataSetId())); parseInfo.setQueryConfig( queryCtx.getSemanticSchema().getQueryConfig(parseResult.getDataSetId())); - parseInfo - .getElementMatches() + parseInfo.getElementMatches() .addAll(queryCtx.getMapInfo().getMatchedElements(parseInfo.getDataSetId())); Map properties = new HashMap<>(); properties.put(Constants.CONTEXT, parseResult); properties.put("type", "internal"); - Text2SQLExemplar exemplar = - Text2SQLExemplar.builder() - .question(queryCtx.getQueryText()) - .sideInfo(parseResult.getLlmResp().getSideInfo()) - .dbSchema(parseResult.getLlmResp().getSchema()) - .sql(parseResult.getLlmResp().getSqlOutput()) - .build(); + Text2SQLExemplar exemplar = Text2SQLExemplar.builder().question(queryCtx.getQueryText()) + .sideInfo(parseResult.getLlmResp().getSideInfo()) + .dbSchema(parseResult.getLlmResp().getSchema()) + .sql(parseResult.getLlmResp().getSqlOutput()).build(); properties.put(Text2SQLExemplar.PROPERTY_KEY, exemplar); parseInfo.setProperties(properties); parseInfo.setScore(queryCtx.getQueryText().length() * (1 + weight)); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java index aec35e0c1..556ca59eb 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/LLMSqlParser.java @@ -61,12 +61,8 @@ public class LLMSqlParser implements SemanticParser { // deduplicate the S2SQL result list and build parserInfo sqlRespMap = responseService.getDeduplicationSqlResp(currentRetry, llmResp); if (MapUtils.isNotEmpty(sqlRespMap)) { - parseResult = - ParseResult.builder() - .dataSetId(dataSetId) - .llmReq(llmReq) - .llmResp(llmResp) - .build(); + parseResult = ParseResult.builder().dataSetId(dataSetId).llmReq(llmReq) + .llmResp(llmResp).build(); break; } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java index f33e2b0dc..ed128c36b 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/OnePassSCSqlGenStrategy.java @@ -25,21 +25,19 @@ import java.util.concurrent.ConcurrentHashMap; @Slf4j public class OnePassSCSqlGenStrategy extends SqlGenStrategy { - public static final String INSTRUCTION = - "" - + "\n#Role: You are a data analyst experienced in SQL languages." - + "\n#Task: You will be provided with a natural language question asked by users," - + "please convert it to a SQL query so that relevant data could be returned " - + "by executing the SQL query against underlying database." - + "\n#Rules:" - + "\n1.ALWAYS generate columns and values specified in the `Schema`, DO NOT hallucinate." - + "\n2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator." - + "\n3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`." - + "\n4.DO NOT calculate date range using functions." - + "\n5.DO NOT calculate date range using DATE_SUB." - + "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed." - + "\n#Exemplars:\n{{exemplar}}" - + "\n#Question:\nQuestion:{{question}},Schema:{{schema}},SideInfo:{{information}}"; + public static final String INSTRUCTION = "" + + "\n#Role: You are a data analyst experienced in SQL languages." + + "\n#Task: You will be provided with a natural language question asked by users," + + "please convert it to a SQL query so that relevant data could be returned " + + "by executing the SQL query against underlying database." + "\n#Rules:" + + "\n1.ALWAYS generate columns and values specified in the `Schema`, DO NOT hallucinate." + + "\n2.ALWAYS specify date filter using `>`,`<`,`>=`,`<=` operator." + + "\n3.DO NOT include date filter in the where clause if not explicitly expressed in the `Question`." + + "\n4.DO NOT calculate date range using functions." + + "\n5.DO NOT calculate date range using DATE_SUB." + + "\n6.DO NOT miss the AGGREGATE operator of metrics, always add it as needed." + + "\n#Exemplars:\n{{exemplar}}" + + "\n#Question:\nQuestion:{{question}},Schema:{{schema}},SideInfo:{{information}}"; @Data static class SemanticSql { @@ -75,21 +73,12 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { // 3.perform multiple self-consistency inferences parallelly Map output2Prompt = new ConcurrentHashMap<>(); - prompt2Exemplar - .keySet() - .parallelStream() - .forEach( - prompt -> { - keyPipelineLog.info( - "OnePassSCSqlGenStrategy reqPrompt:\n{}", - prompt.toUserMessage()); - SemanticSql s2Sql = - extractor.generateSemanticSql( - prompt.toUserMessage().singleText()); - output2Prompt.put(s2Sql.getSql(), prompt); - keyPipelineLog.info( - "OnePassSCSqlGenStrategy modelResp:\n{}", s2Sql.getSql()); - }); + prompt2Exemplar.keySet().parallelStream().forEach(prompt -> { + keyPipelineLog.info("OnePassSCSqlGenStrategy reqPrompt:\n{}", prompt.toUserMessage()); + SemanticSql s2Sql = extractor.generateSemanticSql(prompt.toUserMessage().singleText()); + output2Prompt.put(s2Sql.getSql(), prompt); + keyPipelineLog.info("OnePassSCSqlGenStrategy modelResp:\n{}", s2Sql.getSql()); + }); // 4.format response. Pair> sqlMapPair = @@ -105,13 +94,9 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { private Prompt generatePrompt(LLMReq llmReq, LLMResp llmResp) { StringBuilder exemplars = new StringBuilder(); for (Text2SQLExemplar exemplar : llmReq.getDynamicExemplars()) { - String exemplarStr = - String.format( - "Question:%s,Schema:%s,SideInfo:%s,SQL:%s\n", - exemplar.getQuestion(), - exemplar.getDbSchema(), - exemplar.getSideInfo(), - exemplar.getSql()); + String exemplarStr = String.format("Question:%s,Schema:%s,SideInfo:%s,SQL:%s\n", + exemplar.getQuestion(), exemplar.getDbSchema(), exemplar.getSideInfo(), + exemplar.getSql()); exemplars.append(exemplarStr); } String dataSemantics = promptHelper.buildSchemaStr(llmReq); @@ -136,7 +121,7 @@ public class OnePassSCSqlGenStrategy extends SqlGenStrategy { @Override public void afterPropertiesSet() { - SqlGenStrategyFactory.addSqlGenerationForFactory( - LLMReq.SqlGenType.ONE_PASS_SELF_CONSISTENCY, this); + SqlGenStrategyFactory + .addSqlGenerationForFactory(LLMReq.SqlGenType.ONE_PASS_SELF_CONSISTENCY, this); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java index 9d1f91bd3..750c3ba72 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/PromptHelper.java @@ -24,9 +24,11 @@ import static com.tencent.supersonic.headless.chat.parser.ParserConfig.PARSER_SE @Slf4j public class PromptHelper { - @Autowired private ParserConfig parserConfig; + @Autowired + private ParserConfig parserConfig; - @Autowired private ExemplarService exemplarService; + @Autowired + private ExemplarService exemplarService; public List> getFewShotExemplars(LLMReq llmReq) { int exemplarRecallNumber = @@ -36,11 +38,9 @@ public class PromptHelper { Integer.valueOf(parserConfig.getParameterValue(PARSER_SELF_CONSISTENCY_NUMBER)); List exemplars = Lists.newArrayList(); - llmReq.getDynamicExemplars().stream() - .forEach( - e -> { - exemplars.add(e); - }); + llmReq.getDynamicExemplars().stream().forEach(e -> { + exemplars.add(e); + }); int recallSize = exemplarRecallNumber - llmReq.getDynamicExemplars().size(); if (recallSize > 0) { @@ -79,81 +79,65 @@ public class PromptHelper { String tableStr = llmReq.getSchema().getDataSetName(); List metrics = Lists.newArrayList(); - llmReq.getSchema().getMetrics().stream() - .forEach( - metric -> { - StringBuilder metricStr = new StringBuilder(); - metricStr.append("<"); - metricStr.append(metric.getName()); - if (!CollectionUtils.isEmpty(metric.getAlias())) { - StringBuilder alias = new StringBuilder(); - metric.getAlias().stream().forEach(a -> alias.append(a + ",")); - metricStr.append(" ALIAS '" + alias + "'"); - } - if (StringUtils.isNotEmpty(metric.getDataFormatType())) { - String dataFormatType = metric.getDataFormatType(); - if (DataFormatTypeEnum.DECIMAL - .getName() - .equalsIgnoreCase(dataFormatType) - || DataFormatTypeEnum.PERCENT - .getName() - .equalsIgnoreCase(dataFormatType)) { - metricStr.append(" FORMAT '" + dataFormatType + "'"); - } - } - if (StringUtils.isNotEmpty(metric.getDescription())) { - metricStr.append(" COMMENT '" + metric.getDescription() + "'"); - } - if (StringUtils.isNotEmpty(metric.getDefaultAgg())) { - metricStr.append( - " AGGREGATE '" - + metric.getDefaultAgg().toUpperCase() - + "'"); - } - metricStr.append(">"); - metrics.add(metricStr.toString()); - }); + llmReq.getSchema().getMetrics().stream().forEach(metric -> { + StringBuilder metricStr = new StringBuilder(); + metricStr.append("<"); + metricStr.append(metric.getName()); + if (!CollectionUtils.isEmpty(metric.getAlias())) { + StringBuilder alias = new StringBuilder(); + metric.getAlias().stream().forEach(a -> alias.append(a + ",")); + metricStr.append(" ALIAS '" + alias + "'"); + } + if (StringUtils.isNotEmpty(metric.getDataFormatType())) { + String dataFormatType = metric.getDataFormatType(); + if (DataFormatTypeEnum.DECIMAL.getName().equalsIgnoreCase(dataFormatType) + || DataFormatTypeEnum.PERCENT.getName().equalsIgnoreCase(dataFormatType)) { + metricStr.append(" FORMAT '" + dataFormatType + "'"); + } + } + if (StringUtils.isNotEmpty(metric.getDescription())) { + metricStr.append(" COMMENT '" + metric.getDescription() + "'"); + } + if (StringUtils.isNotEmpty(metric.getDefaultAgg())) { + metricStr.append(" AGGREGATE '" + metric.getDefaultAgg().toUpperCase() + "'"); + } + metricStr.append(">"); + metrics.add(metricStr.toString()); + }); List dimensions = Lists.newArrayList(); - llmReq.getSchema().getDimensions().stream() - .forEach( - dimension -> { - StringBuilder dimensionStr = new StringBuilder(); - dimensionStr.append("<"); - dimensionStr.append(dimension.getName()); - if (!CollectionUtils.isEmpty(dimension.getAlias())) { - StringBuilder alias = new StringBuilder(); - dimension.getAlias().stream().forEach(a -> alias.append(a + ",")); - dimensionStr.append(" ALIAS '" + alias + "'"); - } - if (StringUtils.isNotEmpty(dimension.getTimeFormat())) { - dimensionStr.append(" FORMAT '" + dimension.getTimeFormat() + "'"); - } - if (StringUtils.isNotEmpty(dimension.getDescription())) { - dimensionStr.append( - " COMMENT '" + dimension.getDescription() + "'"); - } - dimensionStr.append(">"); - dimensions.add(dimensionStr.toString()); - }); + llmReq.getSchema().getDimensions().stream().forEach(dimension -> { + StringBuilder dimensionStr = new StringBuilder(); + dimensionStr.append("<"); + dimensionStr.append(dimension.getName()); + if (!CollectionUtils.isEmpty(dimension.getAlias())) { + StringBuilder alias = new StringBuilder(); + dimension.getAlias().stream().forEach(a -> alias.append(a + ",")); + dimensionStr.append(" ALIAS '" + alias + "'"); + } + if (StringUtils.isNotEmpty(dimension.getTimeFormat())) { + dimensionStr.append(" FORMAT '" + dimension.getTimeFormat() + "'"); + } + if (StringUtils.isNotEmpty(dimension.getDescription())) { + dimensionStr.append(" COMMENT '" + dimension.getDescription() + "'"); + } + dimensionStr.append(">"); + dimensions.add(dimensionStr.toString()); + }); List values = Lists.newArrayList(); - llmReq.getSchema().getValues().stream() - .forEach( - value -> { - StringBuilder valueStr = new StringBuilder(); - String fieldName = value.getFieldName(); - String fieldValue = value.getFieldValue(); - valueStr.append(String.format("<%s='%s'>", fieldName, fieldValue)); - values.add(valueStr.toString()); - }); + llmReq.getSchema().getValues().stream().forEach(value -> { + StringBuilder valueStr = new StringBuilder(); + String fieldName = value.getFieldName(); + String fieldValue = value.getFieldValue(); + valueStr.append(String.format("<%s='%s'>", fieldName, fieldValue)); + values.add(valueStr.toString()); + }); String partitionTimeStr = ""; if (llmReq.getSchema().getPartitionTime() != null) { partitionTimeStr = - String.format( - "%s FORMAT '%s'", - llmReq.getSchema().getPartitionTime().getName(), + String.format("%s FORMAT '%s'", llmReq.getSchema().getPartitionTime().getName(), llmReq.getSchema().getPartitionTime().getTimeFormat()); } @@ -170,30 +154,19 @@ public class PromptHelper { String template = "DatabaseType=[%s], Table=[%s], PartitionTimeField=[%s], PrimaryKeyField=[%s], " + "Metrics=[%s], Dimensions=[%s], Values=[%s]"; - return String.format( - template, - databaseTypeStr, - tableStr, - partitionTimeStr, - primaryKeyStr, - String.join(",", metrics), - String.join(",", dimensions), - String.join(",", values)); + return String.format(template, databaseTypeStr, tableStr, partitionTimeStr, primaryKeyStr, + String.join(",", metrics), String.join(",", dimensions), String.join(",", values)); } private String buildTermStr(LLMReq llmReq) { List terms = llmReq.getTerms(); List termStr = Lists.newArrayList(); - terms.stream() - .forEach( - term -> { - StringBuilder termsDesc = new StringBuilder(); - String description = term.getDescription(); - termsDesc.append( - String.format( - "<%s COMMENT '%s'>", term.getName(), description)); - termStr.add(termsDesc.toString()); - }); + terms.stream().forEach(term -> { + StringBuilder termsDesc = new StringBuilder(); + String description = term.getDescription(); + termsDesc.append(String.format("<%s COMMENT '%s'>", term.getName(), description)); + termStr.add(termsDesc.toString()); + }); String ret = ""; if (termStr.size() > 0) { ret = String.join(",", termStr); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ResponseHelper.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ResponseHelper.java index 3c147965f..2b2d85a0e 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ResponseHelper.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/ResponseHelper.java @@ -54,19 +54,13 @@ public class ResponseHelper { return Pair.of(inputMax, votePercentage); } - public static Map buildSqlRespMap( - List sqlExamples, Map sqlMap) { + public static Map buildSqlRespMap(List sqlExamples, + Map sqlMap) { if (sqlMap == null) { return new HashMap<>(); } return sqlMap.entrySet().stream() - .collect( - Collectors.toMap( - Map.Entry::getKey, - entry -> - LLMSqlResp.builder() - .sqlWeight(entry.getValue()) - .fewShots(sqlExamples) - .build())); + .collect(Collectors.toMap(Map.Entry::getKey, entry -> LLMSqlResp.builder() + .sqlWeight(entry.getValue()).fewShots(sqlExamples).build())); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategy.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategy.java index 7f335be60..11fbd9b3f 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategy.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategy.java @@ -20,7 +20,8 @@ public abstract class SqlGenStrategy implements InitializingBean { protected static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); - @Autowired protected PromptHelper promptHelper; + @Autowired + protected PromptHelper promptHelper; protected ChatLanguageModel getChatLanguageModel(ChatModelConfig modelConfig) { return ModelProvider.getChatModel(modelConfig); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategyFactory.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategyFactory.java index e4f5c867e..199164733 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategyFactory.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/llm/SqlGenStrategyFactory.java @@ -14,8 +14,8 @@ public class SqlGenStrategyFactory { return sqlGenStrategyMap.get(strategyType); } - public static void addSqlGenerationForFactory( - LLMReq.SqlGenType strategy, SqlGenStrategy sqlGenStrategy) { + public static void addSqlGenerationForFactory(LLMReq.SqlGenType strategy, + SqlGenStrategy sqlGenStrategy) { sqlGenStrategyMap.put(strategy, sqlGenStrategy); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/AggregateTypeParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/AggregateTypeParser.java index e4efda24a..2d04dd1d7 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/AggregateTypeParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/AggregateTypeParser.java @@ -27,27 +27,20 @@ import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.DISTINC @Slf4j public class AggregateTypeParser implements SemanticParser { - private static final Map REGX_MAP = - Stream.of( - new AbstractMap.SimpleEntry<>( - AggregateTypeEnum.MAX, - Pattern.compile("(?i)(最大值|最大|max|峰值|最高|最多)")), - new AbstractMap.SimpleEntry<>( - AggregateTypeEnum.MIN, - Pattern.compile("(?i)(最小值|最小|min|最低|最少)")), - new AbstractMap.SimpleEntry<>( - AggregateTypeEnum.SUM, Pattern.compile("(?i)(汇总|总和|sum)")), - new AbstractMap.SimpleEntry<>( - AggregateTypeEnum.AVG, Pattern.compile("(?i)(平均值|日均|平均|avg)")), - new AbstractMap.SimpleEntry<>( - AggregateTypeEnum.TOPN, Pattern.compile("(?i)(top)")), - new AbstractMap.SimpleEntry<>(DISTINCT, Pattern.compile("(?i)(uv)")), - new AbstractMap.SimpleEntry<>(COUNT, Pattern.compile("(?i)(总数|pv)")), - new AbstractMap.SimpleEntry<>( - AggregateTypeEnum.NONE, Pattern.compile("(?i)(明细)"))) - .collect( - Collectors.toMap( - Map.Entry::getKey, Map.Entry::getValue, (k1, k2) -> k2)); + private static final Map REGX_MAP = Stream.of( + new AbstractMap.SimpleEntry<>(AggregateTypeEnum.MAX, + Pattern.compile("(?i)(最大值|最大|max|峰值|最高|最多)")), + new AbstractMap.SimpleEntry<>(AggregateTypeEnum.MIN, + Pattern.compile("(?i)(最小值|最小|min|最低|最少)")), + new AbstractMap.SimpleEntry<>(AggregateTypeEnum.SUM, + Pattern.compile("(?i)(汇总|总和|sum)")), + new AbstractMap.SimpleEntry<>(AggregateTypeEnum.AVG, + Pattern.compile("(?i)(平均值|日均|平均|avg)")), + new AbstractMap.SimpleEntry<>(AggregateTypeEnum.TOPN, Pattern.compile("(?i)(top)")), + new AbstractMap.SimpleEntry<>(DISTINCT, Pattern.compile("(?i)(uv)")), + new AbstractMap.SimpleEntry<>(COUNT, Pattern.compile("(?i)(总数|pv)")), + new AbstractMap.SimpleEntry<>(AggregateTypeEnum.NONE, Pattern.compile("(?i)(明细)"))) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (k1, k2) -> k2)); @Override public void parse(ChatQueryContext chatQueryContext) { @@ -63,8 +56,7 @@ public class AggregateTypeParser implements SemanticParser { if (StringUtils.isNotEmpty(aggregateConf.detectWord)) { detectWordLength = aggregateConf.detectWord.length(); } - semanticQuery - .getParseInfo() + semanticQuery.getParseInfo() .setScore(semanticQuery.getParseInfo().getScore() + detectWordLength); } } @@ -93,10 +85,8 @@ public class AggregateTypeParser implements SemanticParser { } AggregateTypeEnum type = - aggregateCount.entrySet().stream() - .max(Map.Entry.comparingByValue()) - .map(entry -> entry.getKey()) - .orElse(AggregateTypeEnum.NONE); + aggregateCount.entrySet().stream().max(Map.Entry.comparingByValue()) + .map(entry -> entry.getKey()).orElse(AggregateTypeEnum.NONE); String detectWord = aggregateWord.get(type); return new AggregateConf(type, detectWord); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/ContextInheritParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/ContextInheritParser.java index 41b703cd4..79e22a17e 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/ContextInheritParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/ContextInheritParser.java @@ -32,25 +32,18 @@ public class ContextInheritParser implements SemanticParser { private static final Map> MUTUAL_EXCLUSIVE_MAP = Stream.of( - new AbstractMap.SimpleEntry<>( - SchemaElementType.METRIC, - Arrays.asList(SchemaElementType.METRIC)), - new AbstractMap.SimpleEntry<>( - SchemaElementType.DIMENSION, - Arrays.asList( - SchemaElementType.DIMENSION, SchemaElementType.VALUE)), - new AbstractMap.SimpleEntry<>( - SchemaElementType.VALUE, - Arrays.asList( - SchemaElementType.VALUE, SchemaElementType.DIMENSION)), - new AbstractMap.SimpleEntry<>( - SchemaElementType.ENTITY, - Arrays.asList(SchemaElementType.ENTITY)), - new AbstractMap.SimpleEntry<>( - SchemaElementType.DATASET, - Arrays.asList(SchemaElementType.DATASET)), - new AbstractMap.SimpleEntry<>( - SchemaElementType.ID, Arrays.asList(SchemaElementType.ID))) + new AbstractMap.SimpleEntry<>(SchemaElementType.METRIC, + Arrays.asList(SchemaElementType.METRIC)), + new AbstractMap.SimpleEntry<>(SchemaElementType.DIMENSION, + Arrays.asList(SchemaElementType.DIMENSION, SchemaElementType.VALUE)), + new AbstractMap.SimpleEntry<>(SchemaElementType.VALUE, + Arrays.asList(SchemaElementType.VALUE, SchemaElementType.DIMENSION)), + new AbstractMap.SimpleEntry<>(SchemaElementType.ENTITY, + Arrays.asList(SchemaElementType.ENTITY)), + new AbstractMap.SimpleEntry<>(SchemaElementType.DATASET, + Arrays.asList(SchemaElementType.DATASET)), + new AbstractMap.SimpleEntry<>(SchemaElementType.ID, + Arrays.asList(SchemaElementType.ID))) .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); @Override @@ -67,13 +60,12 @@ public class ContextInheritParser implements SemanticParser { chatQueryContext.getMapInfo().getMatchedElements(dataSetId); List matchesToInherit = new ArrayList<>(); - for (SchemaElementMatch match : - chatQueryContext.getContextParseInfo().getElementMatches()) { + for (SchemaElementMatch match : chatQueryContext.getContextParseInfo() + .getElementMatches()) { SchemaElementType matchType = match.getElement().getType(); // mutual exclusive element types should not be inherited - RuleSemanticQuery ruleQuery = - QueryManager.getRuleQuery( - chatQueryContext.getContextParseInfo().getQueryMode()); + RuleSemanticQuery ruleQuery = QueryManager + .getRuleQuery(chatQueryContext.getContextParseInfo().getQueryMode()); if (!containsTypes(elementMatches, matchType, ruleQuery)) { match.setInherited(true); matchesToInherit.add(match); @@ -85,16 +77,16 @@ public class ContextInheritParser implements SemanticParser { RuleSemanticQuery.resolve(dataSetId, elementMatches, chatQueryContext); for (RuleSemanticQuery query : queries) { query.fillParseInfo(chatQueryContext); - if (existSameQuery( - query.getParseInfo().getDataSetId(), query.getQueryMode(), chatQueryContext)) { + if (existSameQuery(query.getParseInfo().getDataSetId(), query.getQueryMode(), + chatQueryContext)) { continue; } chatQueryContext.getCandidateQueries().add(query); } } - private boolean existSameQuery( - Long dataSetId, String queryMode, ChatQueryContext chatQueryContext) { + private boolean existSameQuery(Long dataSetId, String queryMode, + ChatQueryContext chatQueryContext) { for (SemanticQuery semanticQuery : chatQueryContext.getCandidateQueries()) { if (semanticQuery.getQueryMode().equals(queryMode) && semanticQuery.getParseInfo().getDataSetId().equals(dataSetId)) { @@ -104,33 +96,26 @@ public class ContextInheritParser implements SemanticParser { return false; } - private boolean containsTypes( - List matches, - SchemaElementType matchType, + private boolean containsTypes(List matches, SchemaElementType matchType, RuleSemanticQuery ruleQuery) { List types = MUTUAL_EXCLUSIVE_MAP.get(matchType); - return matches.stream() - .anyMatch( - m -> { - SchemaElementType type = m.getElement().getType(); - if (Objects.nonNull(ruleQuery) - && ruleQuery instanceof MetricSemanticQuery - && !(ruleQuery instanceof MetricIdQuery)) { - return types.contains(type); - } - return type.equals(matchType); - }); + return matches.stream().anyMatch(m -> { + SchemaElementType type = m.getElement().getType(); + if (Objects.nonNull(ruleQuery) && ruleQuery instanceof MetricSemanticQuery + && !(ruleQuery instanceof MetricIdQuery)) { + return types.contains(type); + } + return type.equals(matchType); + }); } protected boolean shouldInherit(ChatQueryContext chatQueryContext) { // if candidates only have MetricModel mode, count in context List metricModelQueries = chatQueryContext.getCandidateQueries().stream() - .filter( - query -> - query instanceof MetricModelQuery - || query instanceof DetailDimensionQuery) + .filter(query -> query instanceof MetricModelQuery + || query instanceof DetailDimensionQuery) .collect(Collectors.toList()); return metricModelQueries.size() == chatQueryContext.getCandidateQueries().size(); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java index 9bbf72da5..40113401d 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/RuleSqlParser.java @@ -17,9 +17,8 @@ import java.util.List; @Slf4j public class RuleSqlParser implements SemanticParser { - private static List auxiliaryParsers = - Arrays.asList( - new ContextInheritParser(), new TimeRangeParser(), new AggregateTypeParser()); + private static List auxiliaryParsers = Arrays.asList(new ContextInheritParser(), + new TimeRangeParser(), new AggregateTypeParser()); @Override public void parse(ChatQueryContext chatQueryContext) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/TimeRangeParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/TimeRangeParser.java index f88170c2d..7a56aa5f1 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/TimeRangeParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/parser/rule/TimeRangeParser.java @@ -30,9 +30,8 @@ import java.util.regex.Pattern; @Slf4j public class TimeRangeParser implements SemanticParser { - private static final Pattern RECENT_PATTERN_CN = - Pattern.compile( - ".*(?(近|过去)((?\\d+)|(?[一二三四五六七八九十百千万亿]+))个?(?[天周月年])).*"); + private static final Pattern RECENT_PATTERN_CN = Pattern.compile( + ".*(?(近|过去)((?\\d+)|(?[一二三四五六七八九十百千万亿]+))个?(?[天周月年])).*"); private static final Pattern DATE_PATTERN_NUMBER = Pattern.compile("(\\d{8})"); private static final DateFormat DATE_FORMAT_NUMBER = new SimpleDateFormat("yyyyMMdd"); private static final DateFormat DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd"); @@ -70,8 +69,8 @@ public class TimeRangeParser implements SemanticParser { if (queryContext.containsPartitionDimensions(contextParseInfo.getDataSetId())) { contextParseInfo.setDateInfo(dateConf); } - contextParseInfo.setScore( - contextParseInfo.getScore() + dateConf.getDetectWord().length()); + contextParseInfo + .setScore(contextParseInfo.getScore() + dateConf.getDetectWord().length()); semanticQuery.setParseInfo(contextParseInfo); queryContext.getCandidateQueries().add(semanticQuery); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/BaseSemanticQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/BaseSemanticQuery.java index 2093c646c..a6ce039ed 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/BaseSemanticQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/BaseSemanticQuery.java @@ -52,8 +52,8 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable { parseInfo.getSqlInfo().setCorrectedS2SQL(querySQLReq.getSql()); } - protected void convertBizNameToName( - DataSetSchema dataSetSchema, QueryStructReq queryStructReq) { + protected void convertBizNameToName(DataSetSchema dataSetSchema, + QueryStructReq queryStructReq) { Map bizNameToName = dataSetSchema.getBizNameToName(); bizNameToName.putAll(TimeDimensionEnum.getNameToNameMap()); @@ -76,8 +76,8 @@ public abstract class BaseSemanticQuery implements SemanticQuery, Serializable { } List dimensionFilters = queryStructReq.getDimensionFilters(); if (CollectionUtils.isNotEmpty(dimensionFilters)) { - dimensionFilters.forEach( - filter -> filter.setName(bizNameToName.get(filter.getBizName()))); + dimensionFilters + .forEach(filter -> filter.setName(bizNameToName.get(filter.getBizName()))); } List metricFilters = queryStructReq.getMetricFilters(); if (CollectionUtils.isNotEmpty(dimensionFilters)) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/LLMSemanticQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/LLMSemanticQuery.java index 52f8843a0..df8b48745 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/LLMSemanticQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/LLMSemanticQuery.java @@ -4,4 +4,5 @@ import com.tencent.supersonic.headless.chat.query.BaseSemanticQuery; import lombok.extern.slf4j.Slf4j; @Slf4j -public abstract class LLMSemanticQuery extends BaseSemanticQuery {} +public abstract class LLMSemanticQuery extends BaseSemanticQuery { +} diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java index 064e36ee4..cc8e05bda 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/llm/s2sql/LLMReq.java @@ -46,16 +46,12 @@ public class LLMReq { public List getFieldNameList() { List fieldNameList = new ArrayList<>(); if (CollectionUtils.isNotEmpty(metrics)) { - fieldNameList.addAll( - metrics.stream() - .map(metric -> metric.getName()) - .collect(Collectors.toList())); + fieldNameList.addAll(metrics.stream().map(metric -> metric.getName()) + .collect(Collectors.toList())); } if (CollectionUtils.isNotEmpty(dimensions)) { - fieldNameList.addAll( - dimensions.stream() - .map(dimension -> dimension.getName()) - .collect(Collectors.toList())); + fieldNameList.addAll(dimensions.stream().map(dimension -> dimension.getName()) + .collect(Collectors.toList())); } if (Objects.nonNull(partitionTime)) { fieldNameList.add(partitionTime.getName()); @@ -76,6 +72,7 @@ public class LLMReq { public enum SqlGenType { ONE_PASS_SELF_CONSISTENCY("1_pass_self_consistency"); + private String name; SqlGenType(String name) { diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/QueryMatchOption.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/QueryMatchOption.java index 9e5f9dd6e..4f4067cf8 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/QueryMatchOption.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/QueryMatchOption.java @@ -9,10 +9,8 @@ public class QueryMatchOption { private RequireNumberType requireNumberType; private Integer requireNumber; - public static QueryMatchOption build( - OptionType schemaElementOption, - RequireNumberType requireNumberType, - Integer requireNumber) { + public static QueryMatchOption build(OptionType schemaElementOption, + RequireNumberType requireNumberType, Integer requireNumber) { QueryMatchOption queryMatchOption = new QueryMatchOption(); queryMatchOption.requireNumber = requireNumber; queryMatchOption.requireNumberType = requireNumberType; @@ -37,14 +35,10 @@ public class QueryMatchOption { } public enum RequireNumberType { - AT_MOST, - AT_LEAST, - EQUAL + AT_MOST, AT_LEAST, EQUAL } public enum OptionType { - REQUIRED, - OPTIONAL, - UNUSED + REQUIRED, OPTIONAL, UNUSED } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/QueryMatcher.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/QueryMatcher.java index bd498dc7c..f2cf6ddec 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/QueryMatcher.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/QueryMatcher.java @@ -33,13 +33,10 @@ public class QueryMatcher { } } - public QueryMatcher addOption( - SchemaElementType type, - QueryMatchOption.OptionType option, - QueryMatchOption.RequireNumberType requireNumberType, - Integer requireNumber) { - elementOptionMap.put( - type, QueryMatchOption.build(option, requireNumberType, requireNumber)); + public QueryMatcher addOption(SchemaElementType type, QueryMatchOption.OptionType option, + QueryMatchOption.RequireNumberType requireNumberType, Integer requireNumber) { + elementOptionMap.put(type, + QueryMatchOption.build(option, requireNumberType, requireNumber)); return this; } @@ -55,8 +52,8 @@ public class QueryMatcher { for (SchemaElementMatch schemaElementMatch : candidateElementMatches) { SchemaElementType schemaElementType = schemaElementMatch.getElement().getType(); if (schemaElementTypeCount.containsKey(schemaElementType)) { - schemaElementTypeCount.put( - schemaElementType, schemaElementTypeCount.get(schemaElementType) + 1); + schemaElementTypeCount.put(schemaElementType, + schemaElementTypeCount.get(schemaElementType) + 1); } else { schemaElementTypeCount.put(schemaElementType, 1); } @@ -75,10 +72,8 @@ public class QueryMatcher { for (SchemaElementMatch elementMatch : candidateElementMatches) { QueryMatchOption elementOption = elementOptionMap.get(elementMatch.getElement().getType()); - if (Objects.nonNull(elementOption) - && !elementOption - .getSchemaElementOption() - .equals(QueryMatchOption.OptionType.UNUSED)) { + if (Objects.nonNull(elementOption) && !elementOption.getSchemaElementOption() + .equals(QueryMatchOption.OptionType.UNUSED)) { elementMatches.add(elementMatch); } } @@ -86,8 +81,7 @@ public class QueryMatcher { return elementMatches; } - private int getCount( - HashMap schemaElementTypeCount, + private int getCount(HashMap schemaElementTypeCount, SchemaElementType schemaElementType) { if (schemaElementTypeCount.containsKey(schemaElementType)) { return schemaElementTypeCount.get(schemaElementType); @@ -101,15 +95,13 @@ public class QueryMatcher { && count <= 0) { return false; } - if (queryMatchOption - .getRequireNumberType() - .equals(QueryMatchOption.RequireNumberType.AT_LEAST) + if (queryMatchOption.getRequireNumberType() + .equals(QueryMatchOption.RequireNumberType.AT_LEAST) && count < queryMatchOption.getRequireNumber()) { return false; } - if (queryMatchOption - .getRequireNumberType() - .equals(QueryMatchOption.RequireNumberType.AT_MOST) + if (queryMatchOption.getRequireNumberType() + .equals(QueryMatchOption.RequireNumberType.AT_MOST) && count > queryMatchOption.getRequireNumber()) { return false; } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java index 94fea6e2b..95fa69da5 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/RuleSemanticQuery.java @@ -40,8 +40,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { QueryManager.register(this); } - public List match( - List candidateElementMatches, ChatQueryContext queryCtx) { + public List match(List candidateElementMatches, + ChatQueryContext queryCtx) { return queryMatcher.match(candidateElementMatches); } @@ -67,17 +67,16 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { return chatQueryContext.containsPartitionDimensions(dataSetId); } - private void fillDateConfByInherited( - SemanticParseInfo queryParseInfo, ChatQueryContext chatQueryContext) { + private void fillDateConfByInherited(SemanticParseInfo queryParseInfo, + ChatQueryContext chatQueryContext) { SemanticParseInfo contextParseInfo = chatQueryContext.getContextParseInfo(); - if (queryParseInfo.getDateInfo() != null - || contextParseInfo.getDateInfo() == null + if (queryParseInfo.getDateInfo() != null || contextParseInfo.getDateInfo() == null || needFillDateConf(chatQueryContext)) { return; } if ((QueryManager.isDetailQuery(queryParseInfo.getQueryMode()) - && QueryManager.isDetailQuery(contextParseInfo.getQueryMode())) + && QueryManager.isDetailQuery(contextParseInfo.getQueryMode())) || (QueryManager.isMetricQuery(queryParseInfo.getQueryMode()) && QueryManager.isMetricQuery(contextParseInfo.getQueryMode()))) { // inherit date info from context @@ -107,10 +106,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { private void fillSchemaElement(SemanticParseInfo parseInfo, SemanticSchema semanticSchema) { Set dataSetIds = - parseInfo.getElementMatches().stream() - .map(SchemaElementMatch::getElement) - .map(SchemaElement::getDataSetId) - .collect(Collectors.toSet()); + parseInfo.getElementMatches().stream().map(SchemaElementMatch::getElement) + .map(SchemaElement::getDataSetId).collect(Collectors.toSet()); Long dataSetId = dataSetIds.iterator().next(); parseInfo.setDataSet(semanticSchema.getDataSet(dataSetId)); parseInfo.setQueryConfig(semanticSchema.getQueryConfig(dataSetId)); @@ -128,8 +125,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { if (id2Values.containsKey(element.getId())) { id2Values.get(element.getId()).add(schemaMatch); } else { - id2Values.put( - element.getId(), new ArrayList<>(Arrays.asList(schemaMatch))); + id2Values.put(element.getId(), + new ArrayList<>(Arrays.asList(schemaMatch))); } } break; @@ -140,8 +137,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { if (dim2Values.containsKey(element.getId())) { dim2Values.get(element.getId()).add(schemaMatch); } else { - dim2Values.put( - element.getId(), new ArrayList<>(Arrays.asList(schemaMatch))); + dim2Values.put(element.getId(), + new ArrayList<>(Arrays.asList(schemaMatch))); } } break; @@ -161,11 +158,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { addToFilters(dim2Values, parseInfo, semanticSchema, SchemaElementType.DIMENSION); } - private void addToFilters( - Map> id2Values, - SemanticParseInfo parseInfo, - SemanticSchema semanticSchema, - SchemaElementType entity) { + private void addToFilters(Map> id2Values, + SemanticParseInfo parseInfo, SemanticSchema semanticSchema, SchemaElementType entity) { if (id2Values == null || id2Values.isEmpty()) { return; } @@ -206,8 +200,7 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { public SemanticQueryReq multiStructExecute() { String queryMode = parseInfo.getQueryMode(); - if (parseInfo.getDataSetId() != null - || StringUtils.isEmpty(queryMode) + if (parseInfo.getDataSetId() != null || StringUtils.isEmpty(queryMode) || !QueryManager.containsRuleQuery(queryMode)) { // reach here some error may happen log.error("not find QueryMode"); @@ -222,10 +215,8 @@ public abstract class RuleSemanticQuery extends BaseSemanticQuery { this.parseInfo = parseInfo; } - public static List resolve( - Long dataSetId, - List candidateElementMatches, - ChatQueryContext chatQueryContext) { + public static List resolve(Long dataSetId, + List candidateElementMatches, ChatQueryContext chatQueryContext) { List matchedQueries = new ArrayList<>(); for (RuleSemanticQuery semanticQuery : QueryManager.getRuleQueries()) { List matches = diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailListQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailListQuery.java index e1f718304..f127d40d8 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailListQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailListQuery.java @@ -20,8 +20,8 @@ public abstract class DetailListQuery extends DetailSemanticQuery { this.addEntityDetailAndOrderByMetric(chatQueryContext, parseInfo); } - private void addEntityDetailAndOrderByMetric( - ChatQueryContext chatQueryContext, SemanticParseInfo parseInfo) { + private void addEntityDetailAndOrderByMetric(ChatQueryContext chatQueryContext, + SemanticParseInfo parseInfo) { Long dataSetId = parseInfo.getDataSetId(); if (Objects.isNull(dataSetId) || dataSetId <= 0L) { return; @@ -38,35 +38,23 @@ public abstract class DetailListQuery extends DetailSemanticQuery { && detailTypeDefaultConfig.getDefaultDisplayInfo() != null) { if (CollectionUtils.isNotEmpty( detailTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds())) { - metrics = - detailTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds().stream() - .map( - id -> { - SchemaElement metric = - dataSetSchema.getElement( - SchemaElementType.METRIC, id); - if (metric != null) { - orders.add( - new Order( - metric.getBizName(), - Constants.DESC_UPPER)); - } - return metric; - }) - .filter(Objects::nonNull) - .collect(Collectors.toSet()); + metrics = detailTypeDefaultConfig.getDefaultDisplayInfo().getMetricIds() + .stream().map(id -> { + SchemaElement metric = + dataSetSchema.getElement(SchemaElementType.METRIC, id); + if (metric != null) { + orders.add( + new Order(metric.getBizName(), Constants.DESC_UPPER)); + } + return metric; + }).filter(Objects::nonNull).collect(Collectors.toSet()); } if (CollectionUtils.isNotEmpty( detailTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds())) { - dimensions = - detailTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds() - .stream() - .map( - id -> - dataSetSchema.getElement( - SchemaElementType.DIMENSION, id)) - .filter(Objects::nonNull) - .collect(Collectors.toSet()); + dimensions = detailTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds() + .stream() + .map(id -> dataSetSchema.getElement(SchemaElementType.DIMENSION, id)) + .filter(Objects::nonNull).collect(Collectors.toSet()); } } parseInfo.setDimensions(dimensions); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailSemanticQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailSemanticQuery.java index b4b5fa92f..922ad43a1 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailSemanticQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/detail/DetailSemanticQuery.java @@ -23,8 +23,8 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery { } @Override - public List match( - List candidateElementMatches, ChatQueryContext queryCtx) { + public List match(List candidateElementMatches, + ChatQueryContext queryCtx) { return super.match(candidateElementMatches, queryCtx); } @@ -43,8 +43,7 @@ public abstract class DetailSemanticQuery extends RuleSemanticQuery { DataSetSchema dataSetSchema = dataSetSchemaMap.get(parseInfo.getDataSetId()); TimeDefaultConfig timeDefaultConfig = dataSetSchema.getTagTypeTimeDefaultConfig(); - if (Objects.nonNull(timeDefaultConfig) - && Objects.nonNull(timeDefaultConfig.getUnit()) + if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit()) && timeDefaultConfig.getUnit() != -1) { DateConf dateInfo = new DateConf(); int unit = timeDefaultConfig.getUnit(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricFilterQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricFilterQuery.java index d8a4e6d83..f5fde3c1a 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricFilterQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricFilterQuery.java @@ -71,19 +71,15 @@ public class MetricFilterQuery extends MetricSemanticQuery { log.debug("addDimension before [{}]", queryStructReq.getGroups()); List filters = new ArrayList<>(queryStructReq.getDimensionFilters()); if (onlyOperateInFilter) { - filters = - filters.stream() - .filter( - filter -> - filter.getOperator().equals(FilterOperatorEnum.IN)) - .collect(Collectors.toList()); + filters = filters.stream() + .filter(filter -> filter.getOperator().equals(FilterOperatorEnum.IN)) + .collect(Collectors.toList()); } - filters.forEach( - d -> { - if (!dimensions.contains(d.getBizName())) { - dimensions.add(d.getBizName()); - } - }); + filters.forEach(d -> { + if (!dimensions.contains(d.getBizName())) { + dimensions.add(d.getBizName()); + } + }); queryStructReq.setGroups(dimensions); log.debug("addDimension after [{}]", queryStructReq.getGroups()); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricIdQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricIdQuery.java index 45cce5e81..a42b14b0b 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricIdQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricIdQuery.java @@ -46,8 +46,7 @@ public class MetricIdQuery extends MetricSemanticQuery { protected boolean isMultiStructQuery() { Set filterBizName = new HashSet<>(); - parseInfo.getDimensionFilters().stream() - .filter(filter -> filter.getElementID() != null) + parseInfo.getDimensionFilters().stream().filter(filter -> filter.getElementID() != null) .forEach(filter -> filterBizName.add(filter.getBizName())); return FilterType.UNION.equals(parseInfo.getFilterType()) && filterBizName.size() > 1; } @@ -74,19 +73,15 @@ public class MetricIdQuery extends MetricSemanticQuery { log.info("addDimension before [{}]", queryStructReq.getGroups()); List filters = new ArrayList<>(queryStructReq.getDimensionFilters()); if (onlyOperateInFilter) { - filters = - filters.stream() - .filter( - filter -> - filter.getOperator().equals(FilterOperatorEnum.IN)) - .collect(Collectors.toList()); + filters = filters.stream() + .filter(filter -> filter.getOperator().equals(FilterOperatorEnum.IN)) + .collect(Collectors.toList()); } - filters.forEach( - d -> { - if (!dimensions.contains(d.getBizName())) { - dimensions.add(d.getBizName()); - } - }); + filters.forEach(d -> { + if (!dimensions.contains(d.getBizName())) { + dimensions.add(d.getBizName()); + } + }); queryStructReq.setGroups(dimensions); log.info("addDimension after [{}]", queryStructReq.getGroups()); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricSemanticQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricSemanticQuery.java index 29d8b9acc..cbd482cf1 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricSemanticQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricSemanticQuery.java @@ -26,8 +26,8 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery { } @Override - public List match( - List candidateElementMatches, ChatQueryContext queryCtx) { + public List match(List candidateElementMatches, + ChatQueryContext queryCtx) { return super.match(candidateElementMatches, queryCtx); } @@ -42,16 +42,12 @@ public abstract class MetricSemanticQuery extends RuleSemanticQuery { if (parseInfo.getDateInfo() != null || !needFillDateConf(chatQueryContext)) { return; } - DataSetSchema dataSetSchema = - chatQueryContext - .getSemanticSchema() - .getDataSetSchemaMap() - .get(parseInfo.getDataSetId()); + DataSetSchema dataSetSchema = chatQueryContext.getSemanticSchema().getDataSetSchemaMap() + .get(parseInfo.getDataSetId()); TimeDefaultConfig timeDefaultConfig = dataSetSchema.getMetricTypeTimeDefaultConfig(); DateConf dateInfo = new DateConf(); // 加上时间!=-1 判断 - if (Objects.nonNull(timeDefaultConfig) - && Objects.nonNull(timeDefaultConfig.getUnit()) + if (Objects.nonNull(timeDefaultConfig) && Objects.nonNull(timeDefaultConfig.getUnit()) && timeDefaultConfig.getUnit() != -1) { int unit = timeDefaultConfig.getUnit(); String startDate = LocalDate.now().minusDays(unit).toString(); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricTopNQuery.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricTopNQuery.java index e600a5a10..bb02442e2 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricTopNQuery.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/query/rule/metric/MetricTopNQuery.java @@ -33,8 +33,8 @@ public class MetricTopNQuery extends MetricSemanticQuery { } @Override - public List match( - List candidateElementMatches, ChatQueryContext queryCtx) { + public List match(List candidateElementMatches, + ChatQueryContext queryCtx) { Matcher matcher = INTENT_PATTERN.matcher(queryCtx.getQueryText()); if (matcher.matches()) { return super.match(candidateElementMatches, queryCtx); diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/ComponentFactory.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/ComponentFactory.java index fe5b32790..92f0c9717 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/ComponentFactory.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/ComponentFactory.java @@ -26,15 +26,13 @@ public class ComponentFactory { } private static List init(Class factoryType, List list) { - list.addAll( - SpringFactoriesLoader.loadFactories( - factoryType, Thread.currentThread().getContextClassLoader())); + list.addAll(SpringFactoriesLoader.loadFactories(factoryType, + Thread.currentThread().getContextClassLoader())); return list; } private static T init(Class factoryType) { - return SpringFactoriesLoader.loadFactories( - factoryType, Thread.currentThread().getContextClassLoader()) - .get(0); + return SpringFactoriesLoader + .loadFactories(factoryType, Thread.currentThread().getContextClassLoader()).get(0); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/EditDistanceUtils.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/EditDistanceUtils.java index 4ac8c360c..70d1c276b 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/EditDistanceUtils.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/EditDistanceUtils.java @@ -20,8 +20,7 @@ public class EditDistanceUtils { public static double getSimilarity(String detectSegment, String matchName) { String detectSegmentLower = detectSegment == null ? null : detectSegment.toLowerCase(); String matchNameLower = matchName == null ? null : matchName.toLowerCase(); - return 1 - - (double) EditDistance.compute(detectSegmentLower, matchNameLower) - / Math.max(matchName.length(), detectSegment.length()); + return 1 - (double) EditDistance.compute(detectSegmentLower, matchNameLower) + / Math.max(matchName.length(), detectSegment.length()); } } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryFilterParser.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryFilterParser.java index b5c1baf43..3b7443335 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryFilterParser.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryFilterParser.java @@ -13,10 +13,8 @@ public class QueryFilterParser { public static String parse(QueryFilters queryFilters) { try { - List conditions = - queryFilters.getFilters().stream() - .map(QueryFilterParser::parseFilter) - .collect(Collectors.toList()); + List conditions = queryFilters.getFilters().stream() + .map(QueryFilterParser::parseFilter).collect(Collectors.toList()); return String.join(" AND ", conditions); } catch (Exception e) { log.error("", e); @@ -36,10 +34,7 @@ public class QueryFilterParser { case BETWEEN: if (value instanceof List && ((List) value).size() == 2) { List values = (List) value; - return column - + " BETWEEN " - + formatValue(values.get(0)) - + " AND " + return column + " BETWEEN " + formatValue(values.get(0)) + " AND " + formatValue(values.get(1)); } throw new IllegalArgumentException( @@ -58,8 +53,8 @@ public class QueryFilterParser { private static String parseList(Object value) { if (value instanceof List) { - return ((List) value) - .stream().map(QueryFilterParser::formatValue).collect(Collectors.joining(", ")); + return ((List) value).stream().map(QueryFilterParser::formatValue) + .collect(Collectors.joining(", ")); } throw new IllegalArgumentException("IN and NOT IN operators require a list of values"); } diff --git a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java index 495bc113c..0289245a4 100644 --- a/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java +++ b/headless/chat/src/main/java/com/tencent/supersonic/headless/chat/utils/QueryReqBuilder.java @@ -46,15 +46,10 @@ public class QueryReqBuilder { List dimensionFilters = getFilters(parseInfo.getDimensionFilters()); queryStructReq.setDimensionFilters(dimensionFilters); - List metricFilters = - parseInfo.getMetricFilters().stream() - .map( - chatFilter -> - new Filter( - chatFilter.getBizName(), - chatFilter.getOperator(), - chatFilter.getValue())) - .collect(Collectors.toList()); + List metricFilters = parseInfo + .getMetricFilters().stream().map(chatFilter -> new Filter(chatFilter.getBizName(), + chatFilter.getOperator(), chatFilter.getValue())) + .collect(Collectors.toList()); queryStructReq.setMetricFilters(metricFilters); addDateDimension(parseInfo); @@ -62,10 +57,8 @@ public class QueryReqBuilder { if (isDateFieldAlreadyPresent(parseInfo, getDateField(parseInfo.getDateInfo()))) { parseInfo.getDimensions().removeIf(schemaElement -> schemaElement.isPartitionTime()); } - queryStructReq.setGroups( - parseInfo.getDimensions().stream() - .map(SchemaElement::getBizName) - .collect(Collectors.toList())); + queryStructReq.setGroups(parseInfo.getDimensions().stream().map(SchemaElement::getBizName) + .collect(Collectors.toList())); queryStructReq.setLimit(parseInfo.getLimit()); // only one metric is queried at once Set metrics = parseInfo.getMetrics(); @@ -73,8 +66,8 @@ public class QueryReqBuilder { SchemaElement metricElement = parseInfo.getMetrics().iterator().next(); Set order = getOrder(parseInfo.getOrders(), parseInfo.getAggType(), metricElement); - queryStructReq.setAggregators( - getAggregatorByMetric(parseInfo.getAggType(), metricElement)); + queryStructReq + .setAggregators(getAggregatorByMetric(parseInfo.getAggType(), metricElement)); queryStructReq.setOrders(new ArrayList<>(order)); } @@ -87,12 +80,8 @@ public class QueryReqBuilder { List dimensionFilters = queryFilters.stream() .filter(chatFilter -> StringUtils.isNotEmpty(chatFilter.getBizName())) - .map( - chatFilter -> - new Filter( - chatFilter.getBizName(), - chatFilter.getOperator(), - chatFilter.getValue())) + .map(chatFilter -> new Filter(chatFilter.getBizName(), + chatFilter.getOperator(), chatFilter.getValue())) .collect(Collectors.toList()); return dimensionFilters; } @@ -149,21 +138,20 @@ public class QueryReqBuilder { return querySQLReq; } - private static List getAggregatorByMetric( - AggregateTypeEnum aggregateType, SchemaElement metric) { + private static List getAggregatorByMetric(AggregateTypeEnum aggregateType, + SchemaElement metric) { if (metric == null) { return Collections.emptyList(); } String agg = determineAggregator(aggregateType, metric); - return Collections.singletonList( - new Aggregator(metric.getBizName(), AggOperatorEnum.of(agg))); + return Collections + .singletonList(new Aggregator(metric.getBizName(), AggOperatorEnum.of(agg))); } - private static String determineAggregator( - AggregateTypeEnum aggregateType, SchemaElement metric) { - if (aggregateType == null - || aggregateType.equals(AggregateTypeEnum.NONE) + private static String determineAggregator(AggregateTypeEnum aggregateType, + SchemaElement metric) { + if (aggregateType == null || aggregateType.equals(AggregateTypeEnum.NONE) || AggOperatorEnum.COUNT_DISTINCT.name().equalsIgnoreCase(metric.getDefaultAgg())) { return StringUtils.defaultIfBlank(metric.getDefaultAgg(), ""); } @@ -199,28 +187,24 @@ public class QueryReqBuilder { && !CollectionUtils.isEmpty(parseInfo.getDimensions()); } - private static boolean isDateFieldAlreadyPresent( - SemanticParseInfo parseInfo, String dateField) { + private static boolean isDateFieldAlreadyPresent(SemanticParseInfo parseInfo, + String dateField) { return parseInfo.getDimensions().stream() .anyMatch(dimension -> dimension.getBizName().equalsIgnoreCase(dateField)); } private static void addDimension(SemanticParseInfo parseInfo, SchemaElement dimension) { - List timeDimensions = - Arrays.asList( - TimeDimensionEnum.DAY.getName(), - TimeDimensionEnum.WEEK.getName(), - TimeDimensionEnum.MONTH.getName()); - Set dimensions = - parseInfo.getDimensions().stream() - .filter(d -> !timeDimensions.contains(d.getBizName().toLowerCase())) - .collect(Collectors.toSet()); + List timeDimensions = Arrays.asList(TimeDimensionEnum.DAY.getName(), + TimeDimensionEnum.WEEK.getName(), TimeDimensionEnum.MONTH.getName()); + Set dimensions = parseInfo.getDimensions().stream() + .filter(d -> !timeDimensions.contains(d.getBizName().toLowerCase())) + .collect(Collectors.toSet()); dimensions.add(dimension); parseInfo.setDimensions(dimensions); } - public static Set getOrder( - Set existingOrders, AggregateTypeEnum aggregator, SchemaElement metric) { + public static Set getOrder(Set existingOrders, AggregateTypeEnum aggregator, + SchemaElement metric) { if (existingOrders != null && !existingOrders.isEmpty()) { return existingOrders; } @@ -230,8 +214,7 @@ public class QueryReqBuilder { } Set orders = new LinkedHashSet<>(); - if (aggregator == AggregateTypeEnum.TOPN - || aggregator == AggregateTypeEnum.MAX + if (aggregator == AggregateTypeEnum.TOPN || aggregator == AggregateTypeEnum.MAX || aggregator == AggregateTypeEnum.MIN) { Order order = new Order(); order.setColumn(metric.getBizName()); @@ -256,8 +239,8 @@ public class QueryReqBuilder { return dateField; } - public static QueryStructReq buildStructRatioReq( - SemanticParseInfo parseInfo, SchemaElement metric, AggOperatorEnum aggOperatorEnum) { + public static QueryStructReq buildStructRatioReq(SemanticParseInfo parseInfo, + SchemaElement metric, AggOperatorEnum aggOperatorEnum) { QueryStructReq queryStructReq = buildStructReq(parseInfo); queryStructReq.setQueryType(QueryType.AGGREGATE); queryStructReq.setOrders(new ArrayList<>()); diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/AggCorrectorTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/AggCorrectorTest.java index 311585b4e..e8f730c16 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/AggCorrectorTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/AggCorrectorTest.java @@ -27,10 +27,9 @@ class AggCorrectorTest { dataSet.setDataSetId(dataSetId); semanticParseInfo.setDataSet(dataSet); SqlInfo sqlInfo = new SqlInfo(); - String sql = - "SELECT 用户, 访问次数 FROM 超音数数据集 WHERE 部门 = 'sales' AND" - + " datediff('day', 数据日期, '2024-06-04') <= 7" - + " GROUP BY 用户 ORDER BY SUM(访问次数) DESC LIMIT 1"; + String sql = "SELECT 用户, 访问次数 FROM 超音数数据集 WHERE 部门 = 'sales' AND" + + " datediff('day', 数据日期, '2024-06-04') <= 7" + + " GROUP BY 用户 ORDER BY SUM(访问次数) DESC LIMIT 1"; sqlInfo.setParsedS2SQL(sql); sqlInfo.setCorrectedS2SQL(sql); semanticParseInfo.setSqlInfo(sqlInfo); diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrectorTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrectorTest.java index e2a33c3c3..1ec1de163 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrectorTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SchemaCorrectorTest.java @@ -24,26 +24,20 @@ import java.util.Set; @Disabled class SchemaCorrectorTest { - private String json = - "{\n" - + " \"dataSetId\": 1,\n" - + " \"llmReq\": {\n" - + " \"queryText\": \"xxx2024年播放量最高的十首歌\",\n" - + " \"schema\": {\n" - + " \"dataSetName\": \"歌曲\",\n" - + " \"fieldNameList\": [\n" - + " \"商务组\",\n" - + " \"歌曲名\",\n" - + " \"播放量\",\n" - + " \"播放份额\",\n" - + " \"数据日期\"\n" - + " ]\n" - + " },\n" - + " \"currentDate\": \"2024-02-24\",\n" - + " \"sqlGenType\": \"1_pass_self_consistency\"\n" - + " },\n" - + " \"request\": null\n" - + "}"; + private String json = "{\n" + " \"dataSetId\": 1,\n" + " \"llmReq\": {\n" + + " \"queryText\": \"xxx2024年播放量最高的十首歌\",\n" + + " \"schema\": {\n" + + " \"dataSetName\": \"歌曲\",\n" + + " \"fieldNameList\": [\n" + + " \"商务组\",\n" + + " \"歌曲名\",\n" + + " \"播放量\",\n" + + " \"播放份额\",\n" + + " \"数据日期\"\n" + + " ]\n" + " },\n" + + " \"currentDate\": \"2024-02-24\",\n" + + " \"sqlGenType\": \"1_pass_self_consistency\"\n" + + " },\n" + " \"request\": null\n" + "}"; @Test void doCorrect() throws JsonProcessingException { @@ -52,9 +46,8 @@ class SchemaCorrectorTest { ObjectMapper objectMapper = new ObjectMapper(); ParseResult parseResult = objectMapper.readValue(json, ParseResult.class); - String sql = - "select 歌曲名 from 歌曲 where 发行日期 >= '2024-01-01' " - + "and 商务组 = 'xxx' order by 播放量 desc limit 10"; + String sql = "select 歌曲名 from 歌曲 where 发行日期 >= '2024-01-01' " + + "and 商务组 = 'xxx' order by 播放量 desc limit 10"; SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); SqlInfo sqlInfo = new SqlInfo(); sqlInfo.setParsedS2SQL(sql); diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrectorTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrectorTest.java index 92463a00f..7bdc961ee 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrectorTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/SelectCorrectorTest.java @@ -42,8 +42,7 @@ class SelectCorrectorTest { sqlInfo.setCorrectedS2SQL(sql); semanticParseInfo.setSqlInfo(sqlInfo); corrector.correct(chatQueryContext, semanticParseInfo); - Assert.assertEquals( - "SELECT 粉丝数, 国籍, 艺人名, 性别 FROM 艺人库 WHERE 艺人名 = '周杰伦'", + Assert.assertEquals("SELECT 粉丝数, 国籍, 艺人名, 性别 FROM 艺人库 WHERE 艺人名 = '周杰伦'", semanticParseInfo.getSqlInfo().getCorrectedS2SQL()); } diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrectorTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrectorTest.java index ecccd9ced..99e8043ef 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrectorTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/TimeCorrectorTest.java @@ -17,9 +17,8 @@ class TimeCorrectorTest { SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); SqlInfo sqlInfo = new SqlInfo(); // 1.数据日期 <= - String sql = - "SELECT 维度1, SUM(播放量) FROM 数据库 " - + "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1"; + String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " + + "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1"; sqlInfo.setCorrectedS2SQL(sql); semanticParseInfo.setSqlInfo(sqlInfo); corrector.doCorrect(chatQueryContext, semanticParseInfo); @@ -30,9 +29,8 @@ class TimeCorrectorTest { sqlInfo.getCorrectedS2SQL()); // 2.数据日期 < - sql = - "SELECT 维度1, SUM(播放量) FROM 数据库 " - + "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1"; + sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " + + "WHERE (歌手名 = '张三') AND 数据日期 < '2023-11-17' GROUP BY 维度1"; sqlInfo.setCorrectedS2SQL(sql); corrector.doCorrect(chatQueryContext, semanticParseInfo); @@ -42,9 +40,8 @@ class TimeCorrectorTest { sqlInfo.getCorrectedS2SQL()); // 3.数据日期 >= - sql = - "SELECT 维度1, SUM(播放量) FROM 数据库 " - + "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1"; + sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " + + "WHERE (歌手名 = '张三') AND 数据日期 >= '2023-11-17' GROUP BY 维度1"; sqlInfo.setCorrectedS2SQL(sql); corrector.doCorrect(chatQueryContext, semanticParseInfo); @@ -54,9 +51,8 @@ class TimeCorrectorTest { sqlInfo.getCorrectedS2SQL()); // 4.数据日期 > - sql = - "SELECT 维度1, SUM(播放量) FROM 数据库 " - + "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1"; + sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " + + "WHERE (歌手名 = '张三') AND 数据日期 > '2023-11-17' GROUP BY 维度1"; sqlInfo.setCorrectedS2SQL(sql); corrector.doCorrect(chatQueryContext, semanticParseInfo); @@ -70,14 +66,12 @@ class TimeCorrectorTest { sqlInfo.setCorrectedS2SQL(sql); corrector.doCorrect(chatQueryContext, semanticParseInfo); - Assert.assertEquals( - "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE 歌手名 = '张三' GROUP BY 维度1", + Assert.assertEquals("SELECT 维度1, SUM(播放量) FROM 数据库 WHERE 歌手名 = '张三' GROUP BY 维度1", sqlInfo.getCorrectedS2SQL()); // 6. 数据日期-月 <= - sql = - "SELECT 维度1, SUM(播放量) FROM 数据库 " - + "WHERE 歌手名 = '张三' AND 数据日期_月 <= '2024-01' GROUP BY 维度1"; + sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " + + "WHERE 歌手名 = '张三' AND 数据日期_月 <= '2024-01' GROUP BY 维度1"; sqlInfo.setCorrectedS2SQL(sql); corrector.doCorrect(chatQueryContext, semanticParseInfo); @@ -87,9 +81,8 @@ class TimeCorrectorTest { sqlInfo.getCorrectedS2SQL()); // 7. 数据日期-月 > - sql = - "SELECT 维度1, SUM(播放量) FROM 数据库 " - + "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1"; + sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " + + "WHERE 歌手名 = '张三' AND 数据日期_月 > '2024-01' GROUP BY 维度1"; sqlInfo.setCorrectedS2SQL(sql); corrector.doCorrect(chatQueryContext, semanticParseInfo); diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrectorTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrectorTest.java index ab8593f7a..cabd5677e 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrectorTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/corrector/WhereCorrectorTest.java @@ -16,9 +16,8 @@ class WhereCorrectorTest { void addQueryFilter() { SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); SqlInfo sqlInfo = new SqlInfo(); - String sql = - "SELECT 维度1, SUM(播放量) FROM 数据库 " - + "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1"; + String sql = "SELECT 维度1, SUM(播放量) FROM 数据库 " + + "WHERE (歌手名 = '张三') AND 数据日期 <= '2023-11-17' GROUP BY 维度1"; sqlInfo.setCorrectedS2SQL(sql); semanticParseInfo.setSqlInfo(sqlInfo); @@ -56,8 +55,7 @@ class WhereCorrectorTest { String correctS2SQL = semanticParseInfo.getSqlInfo().getCorrectedS2SQL(); - Assert.assertEquals( - correctS2SQL, + Assert.assertEquals(correctS2SQL, "SELECT 维度1, SUM(播放量) FROM 数据库 WHERE " + "(歌手名 = '张三') AND 数据日期 <= '2023-11-17' AND age > 30 AND " + "name LIKE 'John%' AND id IN (1, 2, 3, 4) AND status GROUP BY 维度1"); diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/parser/HeuristicDataSetResolverTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/parser/HeuristicDataSetResolverTest.java index 854fa4471..bbc939339 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/parser/HeuristicDataSetResolverTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/parser/HeuristicDataSetResolverTest.java @@ -25,49 +25,17 @@ public class HeuristicDataSetResolverTest { Map> dataSet2Matches = chatQueryContext.getMapInfo().getDataSetElementMatches(); List matches = Lists.newArrayList(); - matches.add( - SchemaElementMatch.builder() - .element( - SchemaElement.builder() - .dataSetId(1L) - .name("超音数") - .type(SchemaElementType.DATASET) - .build()) - .similarity(1) - .build()); - matches.add( - SchemaElementMatch.builder() - .element( - SchemaElement.builder() - .dataSetId(1L) - .name("访问次数") - .type(SchemaElementType.METRIC) - .build()) - .similarity(0.5) - .build()); + matches.add(SchemaElementMatch.builder().element(SchemaElement.builder().dataSetId(1L) + .name("超音数").type(SchemaElementType.DATASET).build()).similarity(1).build()); + matches.add(SchemaElementMatch.builder().element(SchemaElement.builder().dataSetId(1L) + .name("访问次数").type(SchemaElementType.METRIC).build()).similarity(0.5).build()); dataSet2Matches.put(1L, matches); List matches2 = Lists.newArrayList(); - matches2.add( - SchemaElementMatch.builder() - .element( - SchemaElement.builder() - .dataSetId(2L) - .name("访问用户数") - .type(SchemaElementType.METRIC) - .build()) - .similarity(1) - .build()); - matches2.add( - SchemaElementMatch.builder() - .element( - SchemaElement.builder() - .dataSetId(2L) - .name("用户") - .type(SchemaElementType.DIMENSION) - .build()) - .similarity(1) - .build()); + matches2.add(SchemaElementMatch.builder().element(SchemaElement.builder().dataSetId(2L) + .name("访问用户数").type(SchemaElementType.METRIC).build()).similarity(1).build()); + matches2.add(SchemaElementMatch.builder().element(SchemaElement.builder().dataSetId(2L) + .name("用户").type(SchemaElementType.DIMENSION).build()).similarity(1).build()); dataSet2Matches.put(2L, matches2); Long resolvedDataset = resolver.resolve(chatQueryContext, dataSets); @@ -81,39 +49,15 @@ public class HeuristicDataSetResolverTest { Map> dataSet2Matches = chatQueryContext.getMapInfo().getDataSetElementMatches(); List matches = Lists.newArrayList(); - matches.add( - SchemaElementMatch.builder() - .element( - SchemaElement.builder() - .dataSetId(1L) - .name("访问次数") - .type(SchemaElementType.METRIC) - .build()) - .similarity(1) - .build()); + matches.add(SchemaElementMatch.builder().element(SchemaElement.builder().dataSetId(1L) + .name("访问次数").type(SchemaElementType.METRIC).build()).similarity(1).build()); dataSet2Matches.put(1L, matches); List matches2 = Lists.newArrayList(); - matches2.add( - SchemaElementMatch.builder() - .element( - SchemaElement.builder() - .dataSetId(2L) - .name("访问用户数") - .type(SchemaElementType.METRIC) - .build()) - .similarity(0.6) - .build()); - matches2.add( - SchemaElementMatch.builder() - .element( - SchemaElement.builder() - .dataSetId(2L) - .name("用户") - .type(SchemaElementType.DIMENSION) - .build()) - .similarity(1) - .build()); + matches2.add(SchemaElementMatch.builder().element(SchemaElement.builder().dataSetId(2L) + .name("访问用户数").type(SchemaElementType.METRIC).build()).similarity(0.6).build()); + matches2.add(SchemaElementMatch.builder().element(SchemaElement.builder().dataSetId(2L) + .name("用户").type(SchemaElementType.DIMENSION).build()).similarity(1).build()); dataSet2Matches.put(2L, matches2); Long resolvedDataset = resolver.resolve(chatQueryContext, dataSets); @@ -127,49 +71,17 @@ public class HeuristicDataSetResolverTest { Map> dataSet2Matches = chatQueryContext.getMapInfo().getDataSetElementMatches(); List matches = Lists.newArrayList(); - matches.add( - SchemaElementMatch.builder() - .element( - SchemaElement.builder() - .dataSetId(1L) - .name("访问次数") - .type(SchemaElementType.METRIC) - .build()) - .similarity(0.8) - .build()); - matches.add( - SchemaElementMatch.builder() - .element( - SchemaElement.builder() - .dataSetId(1L) - .name("部门") - .type(SchemaElementType.METRIC) - .build()) - .similarity(0.7) - .build()); + matches.add(SchemaElementMatch.builder().element(SchemaElement.builder().dataSetId(1L) + .name("访问次数").type(SchemaElementType.METRIC).build()).similarity(0.8).build()); + matches.add(SchemaElementMatch.builder().element(SchemaElement.builder().dataSetId(1L) + .name("部门").type(SchemaElementType.METRIC).build()).similarity(0.7).build()); dataSet2Matches.put(1L, matches); List matches2 = Lists.newArrayList(); - matches2.add( - SchemaElementMatch.builder() - .element( - SchemaElement.builder() - .dataSetId(2L) - .name("访问用户数") - .type(SchemaElementType.METRIC) - .build()) - .similarity(0.8) - .build()); - matches2.add( - SchemaElementMatch.builder() - .element( - SchemaElement.builder() - .dataSetId(2L) - .name("用户") - .type(SchemaElementType.DIMENSION) - .build()) - .similarity(1) - .build()); + matches2.add(SchemaElementMatch.builder().element(SchemaElement.builder().dataSetId(2L) + .name("访问用户数").type(SchemaElementType.METRIC).build()).similarity(0.8).build()); + matches2.add(SchemaElementMatch.builder().element(SchemaElement.builder().dataSetId(2L) + .name("用户").type(SchemaElementType.DIMENSION).build()).similarity(1).build()); dataSet2Matches.put(2L, matches2); Long resolvedDataset = resolver.resolve(chatQueryContext, dataSets); diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/s2sql/LLMSqlParserTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/s2sql/LLMSqlParserTest.java index ffda22320..89cb685ae 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/s2sql/LLMSqlParserTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/s2sql/LLMSqlParserTest.java @@ -26,13 +26,8 @@ class LLMSqlParserTest { value1.setAlias(Arrays.asList("周杰倫", "Jay Chou", "周董", "周先生")); schemaValueMaps.add(value1); - SchemaElement schemaElement = - SchemaElement.builder() - .bizName("singer_name") - .name("歌手名") - .dataSetId(2L) - .schemaValueMaps(schemaValueMaps) - .build(); + SchemaElement schemaElement = SchemaElement.builder().bizName("singer_name").name("歌手名") + .dataSetId(2L).schemaValueMaps(schemaValueMaps).build(); dimensions.add(schemaElement); SchemaElement schemaElement2 = diff --git a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/utils/QueryFilterParserTest.java b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/utils/QueryFilterParserTest.java index 5e4c0aee6..e5c23b8ce 100644 --- a/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/utils/QueryFilterParserTest.java +++ b/headless/chat/src/test/java/com/tencent/supersonic/headless/chat/utils/QueryFilterParserTest.java @@ -40,9 +40,7 @@ class QueryFilterParserTest { String parse = QueryFilterParser.parse(queryFilters); - Assert.assertEquals( - parse, - "age > 30 AND name LIKE 'John%' AND id IN (1, 2, 3, 4)" - + " AND status NOT_IN ('inactive', 'deleted')"); + Assert.assertEquals(parse, "age > 30 AND name LIKE 'John%' AND id IN (1, 2, 3, 4)" + + " AND status NOT_IN ('inactive', 'deleted')"); } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/BaseDbAdaptor.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/BaseDbAdaptor.java index 77a75fec7..fd229db45 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/BaseDbAdaptor.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/BaseDbAdaptor.java @@ -72,11 +72,8 @@ public abstract class BaseDbAdaptor implements DbAdaptor { } protected DatabaseMetaData getDatabaseMetaData(ConnectInfo connectionInfo) throws SQLException { - Connection connection = - DriverManager.getConnection( - connectionInfo.getUrl(), - connectionInfo.getUserName(), - connectionInfo.getPassword()); + Connection connection = DriverManager.getConnection(connectionInfo.getUrl(), + connectionInfo.getUserName(), connectionInfo.getPassword()); return connection.getMetaData(); } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/ClickHouseAdaptor.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/ClickHouseAdaptor.java index 6249caf0b..0b9dfd68e 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/ClickHouseAdaptor.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/ClickHouseAdaptor.java @@ -13,11 +13,11 @@ public class ClickHouseAdaptor extends BaseDbAdaptor { public String getDateFormat(String dateType, String dateFormat, String column) { if (dateFormat.equalsIgnoreCase(Constants.DAY_FORMAT_INT)) { if (TimeDimensionEnum.MONTH.name().equalsIgnoreCase(dateType)) { - return "toYYYYMM(toDate(parseDateTimeBestEffort(toString(%s))))" - .replace("%s", column); + return "toYYYYMM(toDate(parseDateTimeBestEffort(toString(%s))))".replace("%s", + column); } else if (TimeDimensionEnum.WEEK.name().equalsIgnoreCase(dateType)) { - return "toMonday(toDate(parseDateTimeBestEffort(toString(%s))))" - .replace("%s", column); + return "toMonday(toDate(parseDateTimeBestEffort(toString(%s))))".replace("%s", + column); } else { return "toDate(parseDateTimeBestEffort(toString(%s)))".replace("%s", column); } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/H2Adaptor.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/H2Adaptor.java index 9c367972b..96775f759 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/H2Adaptor.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/H2Adaptor.java @@ -9,18 +9,18 @@ public class H2Adaptor extends BaseDbAdaptor { public String getDateFormat(String dateType, String dateFormat, String column) { if (dateFormat.equalsIgnoreCase(Constants.DAY_FORMAT_INT)) { if (TimeDimensionEnum.MONTH.name().equalsIgnoreCase(dateType)) { - return "FORMATDATETIME(PARSEDATETIME(%s, 'yyyyMMdd'),'yyyy-MM')" - .replace("%s", column); + return "FORMATDATETIME(PARSEDATETIME(%s, 'yyyyMMdd'),'yyyy-MM')".replace("%s", + column); } else if (TimeDimensionEnum.WEEK.name().equalsIgnoreCase(dateType)) { return "DATE_TRUNC('week',%s)".replace("%s", column); } else { - return "FORMATDATETIME(PARSEDATETIME(%s, 'yyyyMMdd'),'yyyy-MM-dd')" - .replace("%s", column); + return "FORMATDATETIME(PARSEDATETIME(%s, 'yyyyMMdd'),'yyyy-MM-dd')".replace("%s", + column); } } else if (dateFormat.equalsIgnoreCase(Constants.DAY_FORMAT)) { if (TimeDimensionEnum.MONTH.name().equalsIgnoreCase(dateType)) { - return "FORMATDATETIME(PARSEDATETIME(%s, 'yyyy-MM-dd'),'yyyy-MM') " - .replace("%s", column); + return "FORMATDATETIME(PARSEDATETIME(%s, 'yyyy-MM-dd'),'yyyy-MM') ".replace("%s", + column); } else if (TimeDimensionEnum.WEEK.name().equalsIgnoreCase(dateType)) { return "DATE_TRUNC('week',%s)".replace("%s", column); } else { diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/PostgresqlAdaptor.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/PostgresqlAdaptor.java index 1b5ffcf3d..5e7c1e3e0 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/PostgresqlAdaptor.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/adaptor/db/PostgresqlAdaptor.java @@ -53,36 +53,30 @@ public class PostgresqlAdaptor extends BaseDbAdaptor { functionMap.put("DAY".toLowerCase(), "TO_CHAR"); functionMap.put("YEAR".toLowerCase(), "TO_CHAR"); Map functionCall = new HashMap<>(); - functionCall.put( - "MONTH".toLowerCase(), - o -> { - if (Objects.nonNull(o) && o instanceof ExpressionList) { - ExpressionList expressionList = (ExpressionList) o; - expressionList.add(new StringValue("MM")); - return expressionList; - } - return o; - }); - functionCall.put( - "DAY".toLowerCase(), - o -> { - if (Objects.nonNull(o) && o instanceof ExpressionList) { - ExpressionList expressionList = (ExpressionList) o; - expressionList.add(new StringValue("dd")); - return expressionList; - } - return o; - }); - functionCall.put( - "YEAR".toLowerCase(), - o -> { - if (Objects.nonNull(o) && o instanceof ExpressionList) { - ExpressionList expressionList = (ExpressionList) o; - expressionList.add(new StringValue("YYYY")); - return expressionList; - } - return o; - }); + functionCall.put("MONTH".toLowerCase(), o -> { + if (Objects.nonNull(o) && o instanceof ExpressionList) { + ExpressionList expressionList = (ExpressionList) o; + expressionList.add(new StringValue("MM")); + return expressionList; + } + return o; + }); + functionCall.put("DAY".toLowerCase(), o -> { + if (Objects.nonNull(o) && o instanceof ExpressionList) { + ExpressionList expressionList = (ExpressionList) o; + expressionList.add(new StringValue("dd")); + return expressionList; + } + return o; + }); + functionCall.put("YEAR".toLowerCase(), o -> { + if (Objects.nonNull(o) && o instanceof ExpressionList) { + ExpressionList expressionList = (ExpressionList) o; + expressionList.add(new StringValue("YYYY")); + return expressionList; + } + return o; + }); return SqlReplaceHelper.replaceFunction(sql, functionMap, functionCall); } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/CaffeineCacheConfig.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/CaffeineCacheConfig.java index 1e9ac6220..142eb616f 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/CaffeineCacheConfig.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/CaffeineCacheConfig.java @@ -12,7 +12,8 @@ import java.util.concurrent.TimeUnit; @Configuration public class CaffeineCacheConfig { - @Autowired private CacheCommonConfig cacheCommonConfig; + @Autowired + private CacheCommonConfig cacheCommonConfig; @Value("${s2.caffeine.initial.capacity:500}") private Integer caffeineInitialCapacity; @@ -23,19 +24,14 @@ public class CaffeineCacheConfig { @Bean(name = "caffeineCache") public Cache caffeineCache() { return Caffeine.newBuilder() - .expireAfterWrite( - cacheCommonConfig.getCacheCommonExpireAfterWrite(), TimeUnit.MINUTES) - .initialCapacity(caffeineInitialCapacity) - .maximumSize(caffeineMaximumSize) - .build(); + .expireAfterWrite(cacheCommonConfig.getCacheCommonExpireAfterWrite(), + TimeUnit.MINUTES) + .initialCapacity(caffeineInitialCapacity).maximumSize(caffeineMaximumSize).build(); } @Bean(name = "searchCaffeineCache") public Cache searchCaffeineCache() { - return Caffeine.newBuilder() - .expireAfterWrite(10000, TimeUnit.MINUTES) - .initialCapacity(caffeineInitialCapacity) - .maximumSize(caffeineMaximumSize) - .build(); + return Caffeine.newBuilder().expireAfterWrite(10000, TimeUnit.MINUTES) + .initialCapacity(caffeineInitialCapacity).maximumSize(caffeineMaximumSize).build(); } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/CaffeineCacheManager.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/CaffeineCacheManager.java index b7c53fa5b..3cf05df8d 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/CaffeineCacheManager.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/CaffeineCacheManager.java @@ -12,7 +12,8 @@ import org.springframework.stereotype.Component; @Slf4j public class CaffeineCacheManager implements CacheManager { - @Autowired private CacheCommonConfig cacheCommonConfig; + @Autowired + private CacheCommonConfig cacheCommonConfig; @Autowired @Qualifier("caffeineCache") @@ -37,13 +38,9 @@ public class CaffeineCacheManager implements CacheManager { if (StringUtils.isEmpty(prefix)) { prefix = "-1"; } - return Joiner.on(":") - .join( - cacheCommonConfig.getCacheCommonApp(), - cacheCommonConfig.getCacheCommonEnv(), - cacheCommonConfig.getCacheCommonVersion(), - prefix, - body); + return Joiner.on(":").join(cacheCommonConfig.getCacheCommonApp(), + cacheCommonConfig.getCacheCommonEnv(), cacheCommonConfig.getCacheCommonVersion(), + prefix, body); } @Override diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/DefaultQueryCache.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/DefaultQueryCache.java index c8593d7a3..95a2ce69a 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/DefaultQueryCache.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/cache/DefaultQueryCache.java @@ -29,11 +29,10 @@ public class DefaultQueryCache implements QueryCache { CacheCommonConfig cacheCommonConfig = ContextUtils.getBean(CacheCommonConfig.class); if (cacheCommonConfig.getCacheEnable() && Objects.nonNull(value)) { CompletableFuture.supplyAsync(() -> cacheManager.put(cacheKey, value)) - .exceptionally( - exception -> { - log.warn("exception:", exception); - return null; - }); + .exceptionally(exception -> { + log.warn("exception:", exception); + return null; + }); log.debug("put to cache, key: {}", cacheKey); return true; } @@ -48,8 +47,8 @@ public class DefaultQueryCache implements QueryCache { } private String getKeyByModelIds(List modelIds) { - return String.join( - ",", modelIds.stream().map(Object::toString).collect(Collectors.toList())); + return String.join(",", + modelIds.stream().map(Object::toString).collect(Collectors.toList())); } private boolean isCache(SemanticQueryReq semanticQueryReq) { diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/executor/AbstractAccelerator.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/executor/AbstractAccelerator.java index ac3d2e5cc..fd26da4d2 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/executor/AbstractAccelerator.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/executor/AbstractAccelerator.java @@ -56,13 +56,9 @@ public abstract class AbstractAccelerator implements QueryAccelerator { public static final String MATERIALIZATION_SYS_PARTITION = "sys_partition"; /** check if a materialization match the fields and partitions */ - protected boolean check( - RelOptPlanner relOptPlanner, - RelBuilder relBuilder, - CalciteCatalogReader calciteCatalogReader, - Materialization materialization, - List fields, - List> partitions) { + protected boolean check(RelOptPlanner relOptPlanner, RelBuilder relBuilder, + CalciteCatalogReader calciteCatalogReader, Materialization materialization, + List fields, List> partitions) { if (!materialization.isPartitioned()) { return fields.stream().allMatch(f -> materialization.getColumns().contains(f)); } @@ -85,8 +81,8 @@ public abstract class AbstractAccelerator implements QueryAccelerator { } Materialization viewMaterialization = Materialization.builder().build(); - viewMaterialization.setName( - String.format("%s.%s", MATERIALIZATION_SYS_DB, MATERIALIZATION_SYS_VIEW)); + viewMaterialization + .setName(String.format("%s.%s", MATERIALIZATION_SYS_DB, MATERIALIZATION_SYS_VIEW)); viewMaterialization.setColumns(viewFieldList); addMaterialization(calciteCatalogReader.getRootSchema(), viewMaterialization); @@ -97,10 +93,8 @@ public abstract class AbstractAccelerator implements QueryAccelerator { queryMaterialization.setColumns(materializationFieldList); addMaterialization(calciteCatalogReader.getRootSchema(), queryMaterialization); - RelNode replacement = - relBuilder - .scan(Arrays.asList(MATERIALIZATION_SYS_DB, MATERIALIZATION_SYS_VIEW)) - .build(); + RelNode replacement = relBuilder + .scan(Arrays.asList(MATERIALIZATION_SYS_DB, MATERIALIZATION_SYS_VIEW)).build(); RelBuilder viewBuilder = relBuilder.scan(Arrays.asList(MATERIALIZATION_SYS_DB, MATERIALIZATION_SYS_SOURCE)); if (materialization.isPartitioned()) { @@ -117,9 +111,8 @@ public abstract class AbstractAccelerator implements QueryAccelerator { RelBuilder checkBuilder = relBuilder.scan(Arrays.asList(MATERIALIZATION_SYS_DB, MATERIALIZATION_SYS_SOURCE)); if (materialization.isPartitioned()) { - checkBuilder = - checkBuilder.filter( - getRexNode(checkBuilder, partitions, MATERIALIZATION_SYS_PARTITION)); + checkBuilder = checkBuilder + .filter(getRexNode(checkBuilder, partitions, MATERIALIZATION_SYS_PARTITION)); } RelNode checkRel = project(checkBuilder, queryFieldList).build(); relOptPlanner.setRoot(checkRel); @@ -135,12 +128,9 @@ public abstract class AbstractAccelerator implements QueryAccelerator { protected CalciteCatalogReader getCalciteCatalogReader() { CalciteCatalogReader calciteCatalogReader; CalciteSchema viewSchema = SchemaBuilder.getMaterializationSchema(); - calciteCatalogReader = - new CalciteCatalogReader( - CalciteSchema.from(viewSchema.plus()), - CalciteSchema.from(viewSchema.plus()).path(null), - Configuration.typeFactory, - new CalciteConnectionConfigImpl(new Properties())); + calciteCatalogReader = new CalciteCatalogReader(CalciteSchema.from(viewSchema.plus()), + CalciteSchema.from(viewSchema.plus()).path(null), Configuration.typeFactory, + new CalciteConnectionConfigImpl(new Properties())); return calciteCatalogReader; } @@ -151,8 +141,8 @@ public abstract class AbstractAccelerator implements QueryAccelerator { return relOptPlanner; } - protected RelBuilder builderMaterializationPlan( - CalciteCatalogReader calciteCatalogReader, RelOptPlanner relOptPlanner) { + protected RelBuilder builderMaterializationPlan(CalciteCatalogReader calciteCatalogReader, + RelOptPlanner relOptPlanner) { relOptPlanner.addRelTraitDef(ConventionTraitDef.INSTANCE); relOptPlanner.addRelTraitDef(RelDistributionTraitDef.INSTANCE); EnumerableRules.rules().forEach(relOptPlanner::addRule); @@ -161,8 +151,8 @@ public abstract class AbstractAccelerator implements QueryAccelerator { return RelFactories.LOGICAL_BUILDER.create(relOptCluster, calciteCatalogReader); } - protected void addMaterialization( - CalciteSchema dataSetSchema, Materialization materialization) { + protected void addMaterialization(CalciteSchema dataSetSchema, + Materialization materialization) { String[] dbTable = materialization.getName().split("\\."); String tb = dbTable[1].toLowerCase(); String db = dbTable[0].toLowerCase(); @@ -188,34 +178,28 @@ public abstract class AbstractAccelerator implements QueryAccelerator { protected Set extractTableNames(RelNode relNode) { Set tableNames = new HashSet<>(); - RelShuttle shuttle = - new RelHomogeneousShuttle() { - public RelNode visit(TableScan scan) { - RelOptTable table = scan.getTable(); - tableNames.addAll(table.getQualifiedName()); - return scan; - } - }; + RelShuttle shuttle = new RelHomogeneousShuttle() { + public RelNode visit(TableScan scan) { + RelOptTable table = scan.getTable(); + tableNames.addAll(table.getQualifiedName()); + return scan; + } + }; relNode.accept(shuttle); return tableNames; } - protected RexNode getRexNodeByTimeRange( - RelBuilder relBuilder, TimeRange timeRange, String field) { - return relBuilder.call( - SqlStdOperatorTable.AND, - relBuilder.call( - SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, - relBuilder.field(field), + protected RexNode getRexNodeByTimeRange(RelBuilder relBuilder, TimeRange timeRange, + String field) { + return relBuilder.call(SqlStdOperatorTable.AND, + relBuilder.call(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, relBuilder.field(field), relBuilder.literal(timeRange.getStart())), - relBuilder.call( - SqlStdOperatorTable.LESS_THAN_OR_EQUAL, - relBuilder.field(field), + relBuilder.call(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, relBuilder.field(field), relBuilder.literal(timeRange.getEnd()))); } - protected RexNode getRexNode( - RelBuilder relBuilder, Materialization materialization, String viewField) { + protected RexNode getRexNode(RelBuilder relBuilder, Materialization materialization, + String viewField) { RexNode rexNode = null; for (String partition : materialization.getPartitions()) { TimeRange timeRange = TimeRange.builder().start(partition).end(partition).build(); @@ -223,43 +207,26 @@ public abstract class AbstractAccelerator implements QueryAccelerator { rexNode = getRexNodeByTimeRange(relBuilder, timeRange, viewField); continue; } - rexNode = - relBuilder.call( - SqlStdOperatorTable.OR, - rexNode, - getRexNodeByTimeRange(relBuilder, timeRange, viewField)); + rexNode = relBuilder.call(SqlStdOperatorTable.OR, rexNode, + getRexNodeByTimeRange(relBuilder, timeRange, viewField)); } return rexNode; } - protected RexNode getRexNode( - RelBuilder relBuilder, - List> timeRanges, - String viewField) { + protected RexNode getRexNode(RelBuilder relBuilder, + List> timeRanges, String viewField) { RexNode rexNode = null; for (ImmutablePair timeRange : timeRanges) { if (rexNode == null) { - rexNode = - getRexNodeByTimeRange( - relBuilder, - TimeRange.builder() - .start(timeRange.left) - .end(timeRange.right) - .build(), - viewField); + rexNode = getRexNodeByTimeRange(relBuilder, + TimeRange.builder().start(timeRange.left).end(timeRange.right).build(), + viewField); continue; } - rexNode = - relBuilder.call( - SqlStdOperatorTable.OR, - rexNode, - getRexNodeByTimeRange( - relBuilder, - TimeRange.builder() - .start(timeRange.left) - .end(timeRange.right) - .build(), - viewField)); + rexNode = relBuilder.call(SqlStdOperatorTable.OR, rexNode, + getRexNodeByTimeRange(relBuilder, + TimeRange.builder().start(timeRange.left).end(timeRange.right).build(), + viewField)); } return rexNode; } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/executor/JdbcExecutor.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/executor/JdbcExecutor.java index 42f3d5fe6..36a9d7ece 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/executor/JdbcExecutor.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/executor/JdbcExecutor.java @@ -28,8 +28,8 @@ public class JdbcExecutor implements QueryExecutor { SemanticQueryResp semanticQueryResp = queryAccelerator.query(queryStatement); if (Objects.nonNull(semanticQueryResp) && !semanticQueryResp.getResultList().isEmpty()) { - log.info( - "query by Accelerator {}", queryAccelerator.getClass().getSimpleName()); + log.info("query by Accelerator {}", + queryAccelerator.getClass().getSimpleName()); return semanticQueryResp; } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/DuckDbSource.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/DuckDbSource.java index 0b2ca7c6b..8e2c5219d 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/DuckDbSource.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/DuckDbSource.java @@ -63,8 +63,8 @@ public class DuckDbSource { protected void init(JdbcTemplate jdbcTemplate) { jdbcTemplate.execute( String.format("SET memory_limit = '%sGB';", executorConfig.getMemoryLimit())); - jdbcTemplate.execute( - String.format("SET temp_directory='%s';", executorConfig.getDuckDbTemp())); + jdbcTemplate + .execute(String.format("SET temp_directory='%s';", executorConfig.getDuckDbTemp())); jdbcTemplate.execute(String.format("SET threads TO %s;", executorConfig.getThreads())); jdbcTemplate.execute("SET enable_object_cache = true;"); } @@ -82,23 +82,21 @@ public class DuckDbSource { } public void query(String sql, SemanticQueryResp queryResultWithColumns) { - duckDbJdbcTemplate.query( - sql, - rs -> { - if (null == rs) { - return queryResultWithColumns; - } - ResultSetMetaData metaData = rs.getMetaData(); - List queryColumns = new ArrayList<>(); - for (int i = 1; i <= metaData.getColumnCount(); i++) { - String key = metaData.getColumnLabel(i); - queryColumns.add(new QueryColumn(key, metaData.getColumnTypeName(i))); - } - queryResultWithColumns.setColumns(queryColumns); - List> resultList = buildResult(rs); - queryResultWithColumns.setResultList(resultList); - return queryResultWithColumns; - }); + duckDbJdbcTemplate.query(sql, rs -> { + if (null == rs) { + return queryResultWithColumns; + } + ResultSetMetaData metaData = rs.getMetaData(); + List queryColumns = new ArrayList<>(); + for (int i = 1; i <= metaData.getColumnCount(); i++) { + String key = metaData.getColumnLabel(i); + queryColumns.add(new QueryColumn(key, metaData.getColumnTypeName(i))); + } + queryResultWithColumns.setColumns(queryColumns); + List> resultList = buildResult(rs); + queryResultWithColumns.setResultList(resultList); + return queryResultWithColumns; + }); } public static List> buildResult(ResultSet resultSet) { diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/JdbcDataSource.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/JdbcDataSource.java index 7954cbd78..96b553b1d 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/JdbcDataSource.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/pojo/JdbcDataSource.java @@ -203,10 +203,8 @@ public class JdbcDataSource { // default validation query String driverName = druidDataSource.getDriverClassName(); - if (driverName.indexOf("sqlserver") != -1 - || driverName.indexOf("mysql") != -1 - || driverName.indexOf("h2") != -1 - || driverName.indexOf("moonbox") != -1) { + if (driverName.indexOf("sqlserver") != -1 || driverName.indexOf("mysql") != -1 + || driverName.indexOf("h2") != -1 || driverName.indexOf("moonbox") != -1) { druidDataSource.setValidationQuery("select 1"); } @@ -242,12 +240,7 @@ public class JdbcDataSource { } private String getDataSourceKey(Database database) { - return JdbcDataSourceUtils.getKey( - database.getName(), - database.getUrl(), - database.getUsername(), - database.passwordDecrypt(), - "", - false); + return JdbcDataSourceUtils.getKey(database.getName(), database.getUrl(), + database.getUsername(), database.passwordDecrypt(), "", false); } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DefaultSemanticTranslator.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DefaultSemanticTranslator.java index ebcebd067..465309ba6 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DefaultSemanticTranslator.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DefaultSemanticTranslator.java @@ -57,16 +57,12 @@ public class DefaultSemanticTranslator implements SemanticTranslator { headlessConverter.convert(queryStatement); } } - log.debug( - "SemanticConverter after {} {} {}", - queryParam, - queryStatement.getDataSetQueryParam(), - queryStatement.getMetricQueryParam()); + log.debug("SemanticConverter after {} {} {}", queryParam, + queryStatement.getDataSetQueryParam(), queryStatement.getMetricQueryParam()); if (!queryStatement.getDataSetQueryParam().getSql().isEmpty()) { doParse(queryStatement.getDataSetQueryParam(), queryStatement); } else { - queryStatement - .getMetricQueryParam() + queryStatement.getMetricQueryParam() .setNativeQuery(queryParam.getQueryType().isNativeAggQuery()); doParse(queryStatement); } @@ -81,8 +77,8 @@ public class DefaultSemanticTranslator implements SemanticTranslator { } } - public QueryStatement doParse( - DataSetQueryParam dataSetQueryParam, QueryStatement queryStatement) { + public QueryStatement doParse(DataSetQueryParam dataSetQueryParam, + QueryStatement queryStatement) { log.info("parse dataSetQuery [{}] ", dataSetQueryParam); SemanticModel semanticModel = queryStatement.getSemanticModel(); EngineType engineType = EngineType.fromString(semanticModel.getDatabase().getType()); @@ -91,9 +87,8 @@ public class DefaultSemanticTranslator implements SemanticTranslator { List tables = new ArrayList<>(); boolean isSingleTable = dataSetQueryParam.getTables().size() == 1; for (MetricTable metricTable : dataSetQueryParam.getTables()) { - QueryStatement tableSql = - parserSql( - metricTable, isSingleTable, dataSetQueryParam, queryStatement); + QueryStatement tableSql = parserSql(metricTable, isSingleTable, + dataSetQueryParam, queryStatement); if (isSingleTable && StringUtils.isNotBlank(tableSql.getDataSetSimplifySql())) { queryStatement.setSql(tableSql.getDataSetSimplifySql()); queryStatement.setDataSetQueryParam(dataSetQueryParam); @@ -108,26 +103,13 @@ public class DefaultSemanticTranslator implements SemanticTranslator { tables.stream().map(table -> table[0]).collect(Collectors.toList()); List parentSqlList = tables.stream().map(table -> table[1]).collect(Collectors.toList()); - sql = - SqlMergeWithUtils.mergeWith( - engineType, - dataSetQueryParam.getSql(), - parentSqlList, - parentWithNameList); + sql = SqlMergeWithUtils.mergeWith(engineType, dataSetQueryParam.getSql(), + parentSqlList, parentWithNameList); } else { sql = dataSetQueryParam.getSql(); for (String[] tb : tables) { - sql = - StringUtils.replace( - sql, - tb[0], - "(" - + tb[1] - + ") " - + (dataSetQueryParam.isWithAlias() - ? "" - : tb[0]), - -1); + sql = StringUtils.replace(sql, tb[0], "(" + tb[1] + ") " + + (dataSetQueryParam.isWithAlias() ? "" : tb[0]), -1); } } queryStatement.setSql(sql); @@ -143,8 +125,7 @@ public class DefaultSemanticTranslator implements SemanticTranslator { } public QueryStatement doParse(QueryStatement queryStatement) { - return doParse( - queryStatement, + return doParse(queryStatement, AggOption.getAggregation(queryStatement.getMetricQueryParam().isNativeQuery())); } @@ -160,12 +141,8 @@ public class DefaultSemanticTranslator implements SemanticTranslator { return queryStatement; } - private QueryStatement parserSql( - MetricTable metricTable, - Boolean isSingleMetricTable, - DataSetQueryParam dataSetQueryParam, - QueryStatement queryStatement) - throws Exception { + private QueryStatement parserSql(MetricTable metricTable, Boolean isSingleMetricTable, + DataSetQueryParam dataSetQueryParam, QueryStatement queryStatement) throws Exception { MetricQueryParam metricReq = new MetricQueryParam(); metricReq.setMetrics(metricTable.getMetrics()); metricReq.setDimensions(metricTable.getDimensions()); @@ -184,10 +161,8 @@ public class DefaultSemanticTranslator implements SemanticTranslator { } tableSql = doParse(tableSql, metricTable.getAggOption()); if (!tableSql.isOk()) { - throw new Exception( - String.format( - "parser table [%s] error [%s]", - metricTable.getAlias(), tableSql.getErrMsg())); + throw new Exception(String.format("parser table [%s] error [%s]", + metricTable.getAlias(), tableSql.getErrMsg())); } return tableSql; } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DetailQueryOptimizer.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DetailQueryOptimizer.java index e18ad0eab..bf84cbb2b 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DetailQueryOptimizer.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/DetailQueryOptimizer.java @@ -27,11 +27,8 @@ public class DetailQueryOptimizer implements QueryOptimizer { if (queryParam.getMetrics().size() == 0 && !CollectionUtils.isEmpty(queryParam.getGroups())) { String sqlForm = "select %s from ( %s ) src_no_metric"; - String sql = - String.format( - sqlForm, - queryParam.getGroups().stream().collect(Collectors.joining(",")), - sqlRaw); + String sql = String.format(sqlForm, + queryParam.getGroups().stream().collect(Collectors.joining(",")), sqlRaw); queryStatement.setSql(sql); } } @@ -39,8 +36,7 @@ public class DetailQueryOptimizer implements QueryOptimizer { } public boolean isDetailQuery(QueryParam queryParam) { - return Objects.nonNull(queryParam) - && queryParam.getQueryType().isNativeAggQuery() + return Objects.nonNull(queryParam) && queryParam.getQueryType().isNativeAggQuery() && CollectionUtils.isEmpty(queryParam.getMetrics()); } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/CalciteQueryParser.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/CalciteQueryParser.java index ba3db03e7..61589d72b 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/CalciteQueryParser.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/CalciteQueryParser.java @@ -41,14 +41,10 @@ public class CalciteQueryParser implements QueryParser { && Objects.nonNull(queryStatement.getDataSetAlias()) && !queryStatement.getDataSetAlias().isEmpty()) { // simplify model sql with query sql - String simplifySql = - aggBuilder.simplify( - getSqlByDataSet( - engineType, - aggBuilder.getSql(engineType), - queryStatement.getDataSetSql(), - queryStatement.getDataSetAlias()), - engineType); + String simplifySql = aggBuilder.simplify( + getSqlByDataSet(engineType, aggBuilder.getSql(engineType), + queryStatement.getDataSetSql(), queryStatement.getDataSetAlias()), + engineType); if (Objects.nonNull(simplifySql) && !simplifySql.isEmpty()) { log.debug("simplifySql [{}]", simplifySql); queryStatement.setDataSetSimplifySql(simplifySql); @@ -56,8 +52,8 @@ public class CalciteQueryParser implements QueryParser { } } - private SemanticSchema getSemanticSchema( - SemanticModel semanticModel, QueryStatement queryStatement) { + private SemanticSchema getSemanticSchema(SemanticModel semanticModel, + QueryStatement queryStatement) { SemanticSchema semanticSchema = SemanticSchema.newBuilder(semanticModel.getSchemaKey()).build(); semanticSchema.setSemanticModel(semanticModel); @@ -66,20 +62,14 @@ public class CalciteQueryParser implements QueryParser { semanticSchema.setMetric(semanticModel.getMetrics()); semanticSchema.setJoinRelations(semanticModel.getJoinRelations()); semanticSchema.setRuntimeOptions( - RuntimeOptions.builder() - .minMaxTime(queryStatement.getMinMaxTime()) - .enableOptimize(queryStatement.getEnableOptimize()) - .build()); + RuntimeOptions.builder().minMaxTime(queryStatement.getMinMaxTime()) + .enableOptimize(queryStatement.getEnableOptimize()).build()); return semanticSchema; } - private String getSqlByDataSet( - EngineType engineType, String parentSql, String dataSetSql, String parentAlias) - throws SqlParseException { - return SqlMergeWithUtils.mergeWith( - engineType, - dataSetSql, - Collections.singletonList(parentSql), - Collections.singletonList(parentAlias)); + private String getSqlByDataSet(EngineType engineType, String parentSql, String dataSetSql, + String parentAlias) throws SqlParseException { + return SqlMergeWithUtils.mergeWith(engineType, dataSetSql, + Collections.singletonList(parentSql), Collections.singletonList(parentAlias)); } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/planner/AggPlanner.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/planner/AggPlanner.java index 1325022f1..fd12cf821 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/planner/AggPlanner.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/planner/AggPlanner.java @@ -56,7 +56,7 @@ public class AggPlanner implements Planner { isAgg = getAgg(datasource.get(0)); sourceId = String.valueOf(datasource.get(0).getSourceId()); - // build level by level + // build level by level LinkedList builders = new LinkedList<>(); builders.add(new SourceRender()); builders.add(new FilterRender()); @@ -68,9 +68,8 @@ public class AggPlanner implements Planner { Renderer renderer = it.next(); if (previous != null) { previous.render(metricReq, datasource, scope, schema, !isAgg); - renderer.setTable( - previous.builderAs( - DataSourceNode.getNames(datasource) + "_" + String.valueOf(i))); + renderer.setTable(previous + .builderAs(DataSourceNode.getNames(datasource) + "_" + String.valueOf(i))); i++; } previous = renderer; @@ -88,10 +87,8 @@ public class AggPlanner implements Planner { return AggOption.isAgg(aggOption); } // default by dataSource time aggregation - if (Objects.nonNull(dataSource.getAggTime()) - && !dataSource - .getAggTime() - .equalsIgnoreCase(Constants.DIMENSION_TYPE_TIME_GRANULARITY_NONE)) { + if (Objects.nonNull(dataSource.getAggTime()) && !dataSource.getAggTime() + .equalsIgnoreCase(Constants.DIMENSION_TYPE_TIME_GRANULARITY_NONE)) { if (!metricReq.isNativeQuery()) { return true; } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/s2sql/Identify.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/s2sql/Identify.java index 11024c147..c8909003d 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/s2sql/Identify.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/s2sql/Identify.java @@ -10,8 +10,7 @@ import lombok.NoArgsConstructor; public class Identify { public enum Type { - PRIMARY, - FOREIGN + PRIMARY, FOREIGN } private String name; diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/s2sql/Materialization.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/s2sql/Materialization.java index 927c275b6..d4166e862 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/s2sql/Materialization.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/s2sql/Materialization.java @@ -15,10 +15,8 @@ public class Materialization { * partition time type 1 - FULL, not use partition 2 - PARTITION , use time list 3 - ZIPPER, * use [startDate, endDate] range time */ - FULL("FULL"), - PARTITION("PARTITION"), - ZIPPER("ZIPPER"), - None(""); + FULL("FULL"), PARTITION("PARTITION"), ZIPPER("ZIPPER"), None(""); + private String name; TimePartType(String name) { diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/s2sql/SemanticModel.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/s2sql/SemanticModel.java index 71d33ed06..40b4e8171 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/s2sql/SemanticModel.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/s2sql/SemanticModel.java @@ -22,8 +22,7 @@ public class SemanticModel { private Database database; public List getDimensions() { - return dimensionMap.values().stream() - .flatMap(Collection::stream) + return dimensionMap.values().stream().flatMap(Collection::stream) .collect(Collectors.toList()); } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/schema/DataSourceTable.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/schema/DataSourceTable.java index 55bb36ede..e5e11f6dc 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/schema/DataSourceTable.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/schema/DataSourceTable.java @@ -32,10 +32,7 @@ public class DataSourceTable extends AbstractTable implements ScannableTable, Tr private RelDataType rowType; - private DataSourceTable( - String tableName, - List fieldNames, - List fieldTypes, + private DataSourceTable(String tableName, List fieldNames, List fieldTypes, Statistic statistic) { this.tableName = tableName; this.fieldNames = fieldNames; @@ -80,8 +77,8 @@ public class DataSourceTable extends AbstractTable implements ScannableTable, Tr public RelNode toRel(RelOptTable.ToRelContext toRelContext, RelOptTable relOptTable) { List hint = new ArrayList<>(); - return new LogicalTableScan( - toRelContext.getCluster(), toRelContext.getCluster().traitSet(), hint, relOptTable); + return new LogicalTableScan(toRelContext.getCluster(), toRelContext.getCluster().traitSet(), + hint, relOptTable); } public static final class Builder { @@ -128,8 +125,8 @@ public class DataSourceTable extends AbstractTable implements ScannableTable, Tr throw new IllegalStateException("Table must have positive row count"); } - return new DataSourceTable( - tableName, fieldNames, fieldTypes, Statistics.of(rowCount, null)); + return new DataSourceTable(tableName, fieldNames, fieldTypes, + Statistics.of(rowCount, null)); } } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/schema/SchemaBuilder.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/schema/SchemaBuilder.java index 05c173b2f..8cad09dac 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/schema/SchemaBuilder.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/schema/SchemaBuilder.java @@ -31,50 +31,35 @@ public class SchemaBuilder { Map nameToTypeMap = new HashMap<>(); CalciteSchema rootSchema = CalciteSchema.createRootSchema(true, false); rootSchema.add(schema.getSchemaKey(), schema); - Prepare.CatalogReader catalogReader = - new CalciteCatalogReader( - rootSchema, - Collections.singletonList(schema.getSchemaKey()), - Configuration.typeFactory, - Configuration.config); + Prepare.CatalogReader catalogReader = new CalciteCatalogReader(rootSchema, + Collections.singletonList(schema.getSchemaKey()), Configuration.typeFactory, + Configuration.config); EngineType engineType = EngineType.fromString(schema.getSemanticModel().getDatabase().getType()); S2SQLSqlValidatorImpl s2SQLSqlValidator = - new S2SQLSqlValidatorImpl( - Configuration.operatorTable, - catalogReader, - Configuration.typeFactory, - Configuration.getValidatorConfig(engineType)); + new S2SQLSqlValidatorImpl(Configuration.operatorTable, catalogReader, + Configuration.typeFactory, Configuration.getValidatorConfig(engineType)); return new ParameterScope(s2SQLSqlValidator, nameToTypeMap); } public static CalciteSchema getMaterializationSchema() { CalciteSchema rootSchema = CalciteSchema.createRootSchema(true, false); SchemaPlus schema = rootSchema.plus().add(MATERIALIZATION_SYS_DB, new AbstractSchema()); - DataSourceTable srcTable = - DataSourceTable.newBuilder(MATERIALIZATION_SYS_SOURCE) - .addField(MATERIALIZATION_SYS_FIELD_DATE, SqlTypeName.DATE) - .addField(MATERIALIZATION_SYS_FIELD_DATA, SqlTypeName.BIGINT) - .withRowCount(1) - .build(); + DataSourceTable srcTable = DataSourceTable.newBuilder(MATERIALIZATION_SYS_SOURCE) + .addField(MATERIALIZATION_SYS_FIELD_DATE, SqlTypeName.DATE) + .addField(MATERIALIZATION_SYS_FIELD_DATA, SqlTypeName.BIGINT).withRowCount(1) + .build(); schema.add(MATERIALIZATION_SYS_SOURCE, srcTable); - DataSourceTable dataSetTable = - DataSourceTable.newBuilder(MATERIALIZATION_SYS_VIEW) - .addField(MATERIALIZATION_SYS_FIELD_DATE, SqlTypeName.DATE) - .addField(MATERIALIZATION_SYS_FIELD_DATA, SqlTypeName.BIGINT) - .withRowCount(1) - .build(); + DataSourceTable dataSetTable = DataSourceTable.newBuilder(MATERIALIZATION_SYS_VIEW) + .addField(MATERIALIZATION_SYS_FIELD_DATE, SqlTypeName.DATE) + .addField(MATERIALIZATION_SYS_FIELD_DATA, SqlTypeName.BIGINT).withRowCount(1) + .build(); schema.add(MATERIALIZATION_SYS_VIEW, dataSetTable); return rootSchema; } - public static void addSourceView( - CalciteSchema dataSetSchema, - String dbSrc, - String tbSrc, - Set dates, - Set dimensions, - Set metrics) { + public static void addSourceView(CalciteSchema dataSetSchema, String dbSrc, String tbSrc, + Set dates, Set dimensions, Set metrics) { String tb = tbSrc; String db = dbSrc; DataSourceTable.Builder builder = DataSourceTable.newBuilder(tb); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/Renderer.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/Renderer.java index 26357f006..088a98e99 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/Renderer.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/Renderer.java @@ -28,38 +28,28 @@ public abstract class Renderer { protected TableView tableView = new TableView(); public static Optional getDimensionByName(String name, DataSource datasource) { - return datasource.getDimensions().stream() - .filter(d -> d.getName().equalsIgnoreCase(name)) + return datasource.getDimensions().stream().filter(d -> d.getName().equalsIgnoreCase(name)) .findFirst(); } public static Optional getMeasureByName(String name, DataSource datasource) { - return datasource.getMeasures().stream() - .filter(mm -> mm.getName().equalsIgnoreCase(name)) + return datasource.getMeasures().stream().filter(mm -> mm.getName().equalsIgnoreCase(name)) .findFirst(); } public static Optional getMetricByName(String name, SemanticSchema schema) { - Optional metric = - schema.getMetrics().stream() - .filter(m -> m.getName().equalsIgnoreCase(name)) - .findFirst(); + Optional metric = schema.getMetrics().stream() + .filter(m -> m.getName().equalsIgnoreCase(name)).findFirst(); return metric; } public static Optional getIdentifyByName(String name, DataSource datasource) { - return datasource.getIdentifiers().stream() - .filter(i -> i.getName().equalsIgnoreCase(name)) + return datasource.getIdentifiers().stream().filter(i -> i.getName().equalsIgnoreCase(name)) .findFirst(); } - public static MetricNode buildMetricNode( - String metric, - DataSource datasource, - SqlValidatorScope scope, - SemanticSchema schema, - boolean nonAgg, - String alias) + public static MetricNode buildMetricNode(String metric, DataSource datasource, + SqlValidatorScope scope, SemanticSchema schema, boolean nonAgg, String alias) throws Exception { Optional metricOpt = getMetricByName(metric, schema); MetricNode metricNode = new MetricNode(); @@ -69,61 +59,38 @@ public abstract class Renderer { for (Measure m : metricOpt.get().getMetricTypeParams().getMeasures()) { Optional measure = getMeasureByName(m.getName(), datasource); if (measure.isPresent()) { - metricNode - .getNonAggNode() - .put( - measure.get().getName(), - MeasureNode.buildNonAgg( - alias, measure.get(), scope, engineType)); - metricNode - .getAggNode() - .put( - measure.get().getName(), - MeasureNode.buildAgg(measure.get(), nonAgg, scope, engineType)); - metricNode - .getAggFunction() - .put(measure.get().getName(), measure.get().getAgg()); + metricNode.getNonAggNode().put(measure.get().getName(), + MeasureNode.buildNonAgg(alias, measure.get(), scope, engineType)); + metricNode.getAggNode().put(measure.get().getName(), + MeasureNode.buildAgg(measure.get(), nonAgg, scope, engineType)); + metricNode.getAggFunction().put(measure.get().getName(), + measure.get().getAgg()); } else { - metricNode - .getNonAggNode() - .put(m.getName(), MeasureNode.buildNonAgg(alias, m, scope, engineType)); - metricNode - .getAggNode() - .put(m.getName(), MeasureNode.buildAgg(m, nonAgg, scope, engineType)); + metricNode.getNonAggNode().put(m.getName(), + MeasureNode.buildNonAgg(alias, m, scope, engineType)); + metricNode.getAggNode().put(m.getName(), + MeasureNode.buildAgg(m, nonAgg, scope, engineType)); metricNode.getAggFunction().put(m.getName(), m.getAgg()); } if (m.getConstraint() != null && !m.getConstraint().isEmpty()) { - metricNode - .getMeasureFilter() - .put( - m.getName(), - SemanticNode.parse(m.getConstraint(), scope, engineType)); + metricNode.getMeasureFilter().put(m.getName(), + SemanticNode.parse(m.getConstraint(), scope, engineType)); } } return metricNode; } Optional measure = getMeasureByName(metric, datasource); if (measure.isPresent()) { - metricNode - .getNonAggNode() - .put( - measure.get().getName(), - MeasureNode.buildNonAgg(alias, measure.get(), scope, engineType)); - metricNode - .getAggNode() - .put( - measure.get().getName(), - MeasureNode.buildAgg(measure.get(), nonAgg, scope, engineType)); + metricNode.getNonAggNode().put(measure.get().getName(), + MeasureNode.buildNonAgg(alias, measure.get(), scope, engineType)); + metricNode.getAggNode().put(measure.get().getName(), + MeasureNode.buildAgg(measure.get(), nonAgg, scope, engineType)); metricNode.getAggFunction().put(measure.get().getName(), measure.get().getAgg()); if (measure.get().getConstraint() != null && !measure.get().getConstraint().isEmpty()) { - metricNode - .getMeasureFilter() - .put( - measure.get().getName(), - SemanticNode.parse( - measure.get().getConstraint(), scope, engineType)); + metricNode.getMeasureFilter().put(measure.get().getName(), + SemanticNode.parse(measure.get().getConstraint(), scope, engineType)); } } return metricNode; @@ -146,11 +113,6 @@ public abstract class Renderer { return SemanticNode.buildAs(alias, tableView.build()); } - public abstract void render( - MetricQueryParam metricCommand, - List dataSources, - SqlValidatorScope scope, - SemanticSchema schema, - boolean nonAgg) - throws Exception; + public abstract void render(MetricQueryParam metricCommand, List dataSources, + SqlValidatorScope scope, SemanticSchema schema, boolean nonAgg) throws Exception; } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/S2SQLSqlValidatorImpl.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/S2SQLSqlValidatorImpl.java index a3d5c3eda..b3c7c0a57 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/S2SQLSqlValidatorImpl.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/S2SQLSqlValidatorImpl.java @@ -8,11 +8,8 @@ import org.apache.calcite.sql.validate.SqlValidatorImpl; /** customize the SqlValidatorImpl */ public class S2SQLSqlValidatorImpl extends SqlValidatorImpl { - public S2SQLSqlValidatorImpl( - SqlOperatorTable opTab, - SqlValidatorCatalogReader catalogReader, - RelDataTypeFactory typeFactory, - Config config) { + public S2SQLSqlValidatorImpl(SqlOperatorTable opTab, SqlValidatorCatalogReader catalogReader, + RelDataTypeFactory typeFactory, Config config) { super(opTab, catalogReader, typeFactory, config); } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/TableView.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/TableView.java index fd784a88c..8c21132f9 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/TableView.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/TableView.java @@ -39,29 +39,16 @@ public class TableView { if (filter.size() > 0) { filterNodeList = new SqlNodeList(filter, SqlParserPos.ZERO); } - return new SqlSelect( - SqlParserPos.ZERO, - null, - new SqlNodeList(measure, SqlParserPos.ZERO), - table, - filterNodeList, - dimensionNodeList, - null, - null, - null, - order, - offset, - fetch, + return new SqlSelect(SqlParserPos.ZERO, null, new SqlNodeList(measure, SqlParserPos.ZERO), + table, filterNodeList, dimensionNodeList, null, null, null, order, offset, fetch, null); } private List getGroup(List sqlNodeList) { return sqlNodeList.stream() - .map( - s -> - (s.getKind().equals(SqlKind.AS) - ? ((SqlBasicCall) s).getOperandList().get(0) - : s)) + .map(s -> (s.getKind().equals(SqlKind.AS) + ? ((SqlBasicCall) s).getOperandList().get(0) + : s)) .collect(Collectors.toList()); } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/AggFunctionNode.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/AggFunctionNode.java index c1e1dc446..54d0b7981 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/AggFunctionNode.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/AggFunctionNode.java @@ -8,33 +8,19 @@ import java.util.Objects; public class AggFunctionNode extends SemanticNode { - public static SqlNode build( - String agg, String name, SqlValidatorScope scope, EngineType engineType) - throws Exception { + public static SqlNode build(String agg, String name, SqlValidatorScope scope, + EngineType engineType) throws Exception { if (Objects.isNull(agg) || agg.isEmpty()) { return parse(name, scope, engineType); } if (AggFunction.COUNT_DISTINCT.name().equalsIgnoreCase(agg)) { - return parse( - AggFunction.COUNT.name() - + " ( " - + AggFunction.DISTINCT.name() - + " " - + name - + " ) ", - scope, - engineType); + return parse(AggFunction.COUNT.name() + " ( " + AggFunction.DISTINCT.name() + " " + name + + " ) ", scope, engineType); } return parse(agg + " ( " + name + " ) ", scope, engineType); } public static enum AggFunction { - AVG, - COUNT_DISTINCT, - MAX, - MIN, - SUM, - COUNT, - DISTINCT + AVG, COUNT_DISTINCT, MAX, MIN, SUM, COUNT, DISTINCT } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/DataSourceNode.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/DataSourceNode.java index a30c4ddbe..ae3fab3a5 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/DataSourceNode.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/DataSourceNode.java @@ -46,9 +46,8 @@ public class DataSourceNode extends SemanticNode { sqlTable = datasource.getSqlQuery(); } else if (datasource.getTableQuery() != null && !datasource.getTableQuery().isEmpty()) { if (datasource.getType().equalsIgnoreCase(EngineType.POSTGRESQL.getName())) { - String fullTableName = - Arrays.stream(datasource.getTableQuery().split("\\.")) - .collect(Collectors.joining(".public.")); + String fullTableName = Arrays.stream(datasource.getTableQuery().split("\\.")) + .collect(Collectors.joining(".public.")); sqlTable = "select * from " + fullTableName; } else { sqlTable = "select * from " + datasource.getTableQuery(); @@ -76,13 +75,8 @@ public class DataSourceNode extends SemanticNode { } } - private static void addSchemaTable( - SqlValidatorScope scope, - DataSource datasource, - String db, - String tb, - Set fields) - throws Exception { + private static void addSchemaTable(SqlValidatorScope scope, DataSource datasource, String db, + String tb, Set fields) throws Exception { Set dateInfo = new HashSet<>(); Set dimensions = new HashSet<>(); Set metrics = new HashSet<>(); @@ -99,13 +93,11 @@ public class DataSourceNode extends SemanticNode { for (Measure m : datasource.getMeasures()) { List identifiers = expand(SemanticNode.parse(m.getExpr(), scope, engineType), scope); - identifiers.stream() - .forEach( - i -> { - if (!dimensions.contains(i.toString())) { - metrics.add(i.toString()); - } - }); + identifiers.stream().forEach(i -> { + if (!dimensions.contains(i.toString())) { + metrics.add(i.toString()); + } + }); if (!dimensions.contains(m.getName())) { metrics.add(m.getName()); } @@ -116,46 +108,32 @@ public class DataSourceNode extends SemanticNode { log.info("add column {} {}", datasource.getName(), field); } } - SchemaBuilder.addSourceView( - scope.getValidator().getCatalogReader().getRootSchema(), - db, - tb, - dateInfo, - dimensions, - metrics); + SchemaBuilder.addSourceView(scope.getValidator().getCatalogReader().getRootSchema(), db, tb, + dateInfo, dimensions, metrics); } - public static SqlNode buildExtend( - DataSource datasource, Map exprList, SqlValidatorScope scope) - throws Exception { + public static SqlNode buildExtend(DataSource datasource, Map exprList, + SqlValidatorScope scope) throws Exception { if (CollectionUtils.isEmpty(exprList)) { return build(datasource, scope); } EngineType engineType = EngineType.fromString(datasource.getType()); - SqlNode dataSet = - new SqlBasicCall( - new LateralViewExplodeNode(exprList), - Arrays.asList( - build(datasource, scope), - new SqlNodeList( - getExtendField(exprList, scope, engineType), - SqlParserPos.ZERO)), - SqlParserPos.ZERO); + SqlNode dataSet = new SqlBasicCall(new LateralViewExplodeNode(exprList), + Arrays.asList(build(datasource, scope), new SqlNodeList( + getExtendField(exprList, scope, engineType), SqlParserPos.ZERO)), + SqlParserPos.ZERO); return buildAs(datasource.getName() + Constants.DIMENSION_ARRAY_SINGLE_SUFFIX, dataSet); } - public static List getExtendField( - Map exprList, SqlValidatorScope scope, EngineType engineType) - throws Exception { + public static List getExtendField(Map exprList, + SqlValidatorScope scope, EngineType engineType) throws Exception { List sqlNodeList = new ArrayList<>(); for (String expr : exprList.keySet()) { sqlNodeList.add(parse(expr, scope, engineType)); - sqlNodeList.add( - new SqlDataTypeSpec( - new SqlUserDefinedTypeNameSpec( - expr + Constants.DIMENSION_ARRAY_SINGLE_SUFFIX, - SqlParserPos.ZERO), - SqlParserPos.ZERO)); + sqlNodeList.add(new SqlDataTypeSpec( + new SqlUserDefinedTypeNameSpec(expr + Constants.DIMENSION_ARRAY_SINGLE_SUFFIX, + SqlParserPos.ZERO), + SqlParserPos.ZERO)); } return sqlNodeList; } @@ -172,45 +150,31 @@ public class DataSourceNode extends SemanticNode { return dataSourceList.stream().map(d -> d.getName()).collect(Collectors.joining("_")); } - public static void getQueryDimensionMeasure( - SemanticSchema schema, - MetricQueryParam metricCommand, - Set queryDimension, - List measures) { - queryDimension.addAll( - metricCommand.getDimensions().stream() - .map( - d -> - d.contains(Constants.DIMENSION_IDENTIFY) - ? d.split(Constants.DIMENSION_IDENTIFY)[1] - : d) - .collect(Collectors.toSet())); + public static void getQueryDimensionMeasure(SemanticSchema schema, + MetricQueryParam metricCommand, Set queryDimension, List measures) { + queryDimension.addAll(metricCommand.getDimensions().stream() + .map(d -> d.contains(Constants.DIMENSION_IDENTIFY) + ? d.split(Constants.DIMENSION_IDENTIFY)[1] + : d) + .collect(Collectors.toSet())); Set schemaMetricName = schema.getMetrics().stream().map(m -> m.getName()).collect(Collectors.toSet()); - schema.getMetrics().stream() - .filter(m -> metricCommand.getMetrics().contains(m.getName())) - .forEach( - m -> - m.getMetricTypeParams().getMeasures().stream() - .forEach(mm -> measures.add(mm.getName()))); - metricCommand.getMetrics().stream() - .filter(m -> !schemaMetricName.contains(m)) + schema.getMetrics().stream().filter(m -> metricCommand.getMetrics().contains(m.getName())) + .forEach(m -> m.getMetricTypeParams().getMeasures().stream() + .forEach(mm -> measures.add(mm.getName()))); + metricCommand.getMetrics().stream().filter(m -> !schemaMetricName.contains(m)) .forEach(m -> measures.add(m)); } - public static void mergeQueryFilterDimensionMeasure( - SemanticSchema schema, - MetricQueryParam metricCommand, - Set queryDimension, - List measures, - SqlValidatorScope scope) - throws Exception { + public static void mergeQueryFilterDimensionMeasure(SemanticSchema schema, + MetricQueryParam metricCommand, Set queryDimension, List measures, + SqlValidatorScope scope) throws Exception { EngineType engineType = EngineType.fromString(schema.getSemanticModel().getDatabase().getType()); if (Objects.nonNull(metricCommand.getWhere()) && !metricCommand.getWhere().isEmpty()) { Set filterConditions = new HashSet<>(); - FilterNode.getFilterField( - parse(metricCommand.getWhere(), scope, engineType), filterConditions); + FilterNode.getFilterField(parse(metricCommand.getWhere(), scope, engineType), + filterConditions); Set queryMeasures = new HashSet<>(measures); Set schemaMetricName = schema.getMetrics().stream().map(m -> m.getName()).collect(Collectors.toSet()); @@ -218,11 +182,8 @@ public class DataSourceNode extends SemanticNode { if (schemaMetricName.contains(filterCondition)) { schema.getMetrics().stream() .filter(m -> m.getName().equalsIgnoreCase(filterCondition)) - .forEach( - m -> - m.getMetricTypeParams().getMeasures().stream() - .forEach( - mm -> queryMeasures.add(mm.getName()))); + .forEach(m -> m.getMetricTypeParams().getMeasures().stream() + .forEach(mm -> queryMeasures.add(mm.getName()))); continue; } queryDimension.add(filterCondition); @@ -232,9 +193,8 @@ public class DataSourceNode extends SemanticNode { } } - public static List getMatchDataSources( - SqlValidatorScope scope, SemanticSchema schema, MetricQueryParam metricCommand) - throws Exception { + public static List getMatchDataSources(SqlValidatorScope scope, + SemanticSchema schema, MetricQueryParam metricCommand) throws Exception { List dataSources = new ArrayList<>(); // check by metric @@ -245,18 +205,14 @@ public class DataSourceNode extends SemanticNode { // one , match measure count Map dataSourceMeasures = new HashMap<>(); for (Map.Entry entry : schema.getDatasource().entrySet()) { - Set sourceMeasure = - entry.getValue().getMeasures().stream() - .map(mm -> mm.getName()) - .collect(Collectors.toSet()); + Set sourceMeasure = entry.getValue().getMeasures().stream() + .map(mm -> mm.getName()).collect(Collectors.toSet()); sourceMeasure.retainAll(measures); dataSourceMeasures.put(entry.getKey(), sourceMeasure.size()); } log.info("dataSourceMeasures [{}]", dataSourceMeasures); - Optional> base = - dataSourceMeasures.entrySet().stream() - .sorted(Map.Entry.comparingByValue(Comparator.reverseOrder())) - .findFirst(); + Optional> base = dataSourceMeasures.entrySet().stream() + .sorted(Map.Entry.comparingByValue(Comparator.reverseOrder())).findFirst(); if (base.isPresent()) { baseDataSource = schema.getDatasource().get(base.get().getKey()); dataSources.add(baseDataSource); @@ -264,14 +220,10 @@ public class DataSourceNode extends SemanticNode { // second , check match all dimension and metric if (baseDataSource != null) { Set filterMeasure = new HashSet<>(); - Set sourceMeasure = - baseDataSource.getMeasures().stream() - .map(mm -> mm.getName()) - .collect(Collectors.toSet()); - Set dimension = - baseDataSource.getDimensions().stream() - .map(dd -> dd.getName()) - .collect(Collectors.toSet()); + Set sourceMeasure = baseDataSource.getMeasures().stream() + .map(mm -> mm.getName()).collect(Collectors.toSet()); + Set dimension = baseDataSource.getDimensions().stream().map(dd -> dd.getName()) + .collect(Collectors.toSet()); baseDataSource.getIdentifiers().stream().forEach(i -> dimension.add(i.getName())); if (schema.getDimension().containsKey(baseDataSource.getName())) { schema.getDimension().get(baseDataSource.getName()).stream() @@ -281,43 +233,31 @@ public class DataSourceNode extends SemanticNode { filterMeasure.addAll(dimension); EngineType engineType = EngineType.fromString(schema.getSemanticModel().getDatabase().getType()); - mergeQueryFilterDimensionMeasure( - schema, metricCommand, queryDimension, measures, scope); - boolean isAllMatch = - checkMatch( - sourceMeasure, - queryDimension, - measures, - dimension, - metricCommand, - scope, - engineType); + mergeQueryFilterDimensionMeasure(schema, metricCommand, queryDimension, measures, + scope); + boolean isAllMatch = checkMatch(sourceMeasure, queryDimension, measures, dimension, + metricCommand, scope, engineType); if (isAllMatch) { log.debug("baseDataSource match all "); return dataSources; } // find all dataSource has the same identifiers - List linkDataSources = - getLinkDataSourcesByJoinRelation( - queryDimension, measures, baseDataSource, schema); + List linkDataSources = getLinkDataSourcesByJoinRelation(queryDimension, + measures, baseDataSource, schema); if (CollectionUtils.isEmpty(linkDataSources)) { log.debug("baseDataSource get by identifiers "); - Set baseIdentifiers = - baseDataSource.getIdentifiers().stream() - .map(i -> i.getName()) - .collect(Collectors.toSet()); + Set baseIdentifiers = baseDataSource.getIdentifiers().stream() + .map(i -> i.getName()).collect(Collectors.toSet()); if (baseIdentifiers.isEmpty()) { throw new Exception( "datasource error : " + baseDataSource.getName() + " miss identifier"); } - linkDataSources = - getLinkDataSources( - baseIdentifiers, queryDimension, measures, baseDataSource, schema); + linkDataSources = getLinkDataSources(baseIdentifiers, queryDimension, measures, + baseDataSource, schema); if (linkDataSources.isEmpty()) { - throw new Exception( - String.format( - "not find the match datasource : dimension[%s],measure[%s]", - queryDimension, measures)); + throw new Exception(String.format( + "not find the match datasource : dimension[%s],measure[%s]", + queryDimension, measures)); } } log.debug("linkDataSources {}", linkDataSources); @@ -328,15 +268,9 @@ public class DataSourceNode extends SemanticNode { return dataSources; } - private static boolean checkMatch( - Set sourceMeasure, - Set queryDimension, - List measures, - Set dimension, - MetricQueryParam metricCommand, - SqlValidatorScope scope, - EngineType engineType) - throws Exception { + private static boolean checkMatch(Set sourceMeasure, Set queryDimension, + List measures, Set dimension, MetricQueryParam metricCommand, + SqlValidatorScope scope, EngineType engineType) throws Exception { boolean isAllMatch = true; sourceMeasure.retainAll(measures); if (sourceMeasure.size() < measures.size()) { @@ -367,11 +301,8 @@ public class DataSourceNode extends SemanticNode { return isAllMatch; } - private static List getLinkDataSourcesByJoinRelation( - Set queryDimension, - List measures, - DataSource baseDataSource, - SemanticSchema schema) { + private static List getLinkDataSourcesByJoinRelation(Set queryDimension, + List measures, DataSource baseDataSource, SemanticSchema schema) { Set linkDataSourceName = new HashSet<>(); List linkDataSources = new ArrayList<>(); Set before = new HashSet<>(); @@ -379,13 +310,9 @@ public class DataSourceNode extends SemanticNode { if (!CollectionUtils.isEmpty(schema.getJoinRelations())) { Set visitJoinRelations = new HashSet<>(); List sortedJoinRelation = new ArrayList<>(); - sortJoinRelation( - schema.getJoinRelations(), - baseDataSource.getName(), - visitJoinRelations, - sortedJoinRelation); - schema.getJoinRelations().stream() - .filter(j -> !visitJoinRelations.contains(j.getId())) + sortJoinRelation(schema.getJoinRelations(), baseDataSource.getName(), + visitJoinRelations, sortedJoinRelation); + schema.getJoinRelations().stream().filter(j -> !visitJoinRelations.contains(j.getId())) .forEach(j -> sortedJoinRelation.add(j)); for (JoinRelation joinRelation : sortedJoinRelation) { if (!before.contains(joinRelation.getLeft()) @@ -394,34 +321,26 @@ public class DataSourceNode extends SemanticNode { } boolean isMatch = false; boolean isRight = before.contains(joinRelation.getLeft()); - DataSource other = - isRight - ? schema.getDatasource().get(joinRelation.getRight()) - : schema.getDatasource().get(joinRelation.getLeft()); + DataSource other = isRight ? schema.getDatasource().get(joinRelation.getRight()) + : schema.getDatasource().get(joinRelation.getLeft()); if (!queryDimension.isEmpty()) { - Set linkDimension = - other.getDimensions().stream() - .map(dd -> dd.getName()) - .collect(Collectors.toSet()); + Set linkDimension = other.getDimensions().stream() + .map(dd -> dd.getName()).collect(Collectors.toSet()); other.getIdentifiers().stream().forEach(i -> linkDimension.add(i.getName())); linkDimension.retainAll(queryDimension); if (!linkDimension.isEmpty()) { isMatch = true; } } - Set linkMeasure = - other.getMeasures().stream() - .map(mm -> mm.getName()) - .collect(Collectors.toSet()); + Set linkMeasure = other.getMeasures().stream().map(mm -> mm.getName()) + .collect(Collectors.toSet()); linkMeasure.retainAll(measures); if (!linkMeasure.isEmpty()) { isMatch = true; } if (!isMatch && schema.getDimension().containsKey(other.getName())) { - Set linkDimension = - schema.getDimension().get(other.getName()).stream() - .map(dd -> dd.getName()) - .collect(Collectors.toSet()); + Set linkDimension = schema.getDimension().get(other.getName()).stream() + .map(dd -> dd.getName()).collect(Collectors.toSet()); linkDimension.retainAll(queryDimension); if (!linkDimension.isEmpty()) { isMatch = true; @@ -444,41 +363,30 @@ public class DataSourceNode extends SemanticNode { orders.put(joinRelation.getRight(), 1L); } } - orders.entrySet().stream() - .sorted(Map.Entry.comparingByValue()) - .forEach( - d -> { - linkDataSources.add(schema.getDatasource().get(d.getKey())); - }); + orders.entrySet().stream().sorted(Map.Entry.comparingByValue()).forEach(d -> { + linkDataSources.add(schema.getDatasource().get(d.getKey())); + }); } return linkDataSources; } - private static void sortJoinRelation( - List joinRelations, - String next, - Set visited, - List sortedJoins) { + private static void sortJoinRelation(List joinRelations, String next, + Set visited, List sortedJoins) { for (JoinRelation link : joinRelations) { if (!visited.contains(link.getId())) { if (link.getLeft().equals(next) || link.getRight().equals(next)) { visited.add(link.getId()); sortedJoins.add(link); - sortJoinRelation( - joinRelations, - link.getLeft().equals(next) ? link.getRight() : link.getLeft(), - visited, + sortJoinRelation(joinRelations, + link.getLeft().equals(next) ? link.getRight() : link.getLeft(), visited, sortedJoins); } } } } - private static List getLinkDataSources( - Set baseIdentifiers, - Set queryDimension, - List measures, - DataSource baseDataSource, + private static List getLinkDataSources(Set baseIdentifiers, + Set queryDimension, List measures, DataSource baseDataSource, SemanticSchema schema) { Set linkDataSourceName = new HashSet<>(); List linkDataSources = new ArrayList<>(); @@ -486,18 +394,13 @@ public class DataSourceNode extends SemanticNode { if (entry.getKey().equalsIgnoreCase(baseDataSource.getName())) { continue; } - Long identifierNum = - entry.getValue().getIdentifiers().stream() - .map(i -> i.getName()) - .filter(i -> baseIdentifiers.contains(i)) - .count(); + Long identifierNum = entry.getValue().getIdentifiers().stream().map(i -> i.getName()) + .filter(i -> baseIdentifiers.contains(i)).count(); if (identifierNum > 0) { boolean isMatch = false; if (!queryDimension.isEmpty()) { - Set linkDimension = - entry.getValue().getDimensions().stream() - .map(dd -> dd.getName()) - .collect(Collectors.toSet()); + Set linkDimension = entry.getValue().getDimensions().stream() + .map(dd -> dd.getName()).collect(Collectors.toSet()); entry.getValue().getIdentifiers().stream() .forEach(i -> linkDimension.add(i.getName())); linkDimension.retainAll(queryDimension); @@ -506,10 +409,8 @@ public class DataSourceNode extends SemanticNode { } } if (!measures.isEmpty()) { - Set linkMeasure = - entry.getValue().getMeasures().stream() - .map(mm -> mm.getName()) - .collect(Collectors.toSet()); + Set linkMeasure = entry.getValue().getMeasures().stream() + .map(mm -> mm.getName()).collect(Collectors.toSet()); linkMeasure.retainAll(measures); if (!linkMeasure.isEmpty()) { isMatch = true; @@ -522,10 +423,8 @@ public class DataSourceNode extends SemanticNode { } for (Map.Entry> entry : schema.getDimension().entrySet()) { if (!queryDimension.isEmpty()) { - Set linkDimension = - entry.getValue().stream() - .map(dd -> dd.getName()) - .collect(Collectors.toSet()); + Set linkDimension = entry.getValue().stream().map(dd -> dd.getName()) + .collect(Collectors.toSet()); linkDimension.retainAll(queryDimension); if (!linkDimension.isEmpty()) { linkDataSourceName.add(entry.getKey()); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/DimensionNode.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/DimensionNode.java index b32e8955e..6119e33ba 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/DimensionNode.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/DimensionNode.java @@ -17,25 +17,24 @@ public class DimensionNode extends SemanticNode { return buildAs(dimension.getName(), sqlNode); } - public static List expand( - Dimension dimension, SqlValidatorScope scope, EngineType engineType) throws Exception { + public static List expand(Dimension dimension, SqlValidatorScope scope, + EngineType engineType) throws Exception { SqlNode sqlNode = parse(dimension.getExpr(), scope, engineType); return expand(sqlNode, scope); } - public static SqlNode buildName( - Dimension dimension, SqlValidatorScope scope, EngineType engineType) throws Exception { + public static SqlNode buildName(Dimension dimension, SqlValidatorScope scope, + EngineType engineType) throws Exception { return parse(dimension.getName(), scope, engineType); } - public static SqlNode buildExp( - Dimension dimension, SqlValidatorScope scope, EngineType engineType) throws Exception { + public static SqlNode buildExp(Dimension dimension, SqlValidatorScope scope, + EngineType engineType) throws Exception { return parse(dimension.getExpr(), scope, engineType); } - public static SqlNode buildNameAs( - String alias, Dimension dimension, SqlValidatorScope scope, EngineType engineType) - throws Exception { + public static SqlNode buildNameAs(String alias, Dimension dimension, SqlValidatorScope scope, + EngineType engineType) throws Exception { if ("".equals(alias)) { return buildName(dimension, scope, engineType); } @@ -43,16 +42,13 @@ public class DimensionNode extends SemanticNode { return buildAs(alias, sqlNode); } - public static SqlNode buildArray( - Dimension dimension, SqlValidatorScope scope, EngineType engineType) throws Exception { + public static SqlNode buildArray(Dimension dimension, SqlValidatorScope scope, + EngineType engineType) throws Exception { if (Objects.nonNull(dimension.getDataType()) && dimension.getDataType().isArray()) { SqlNode sqlNode = parse(dimension.getExpr(), scope, engineType); if (isIdentifier(sqlNode)) { - return buildAs( - dimension.getName(), - parse( - dimension.getExpr() + Constants.DIMENSION_ARRAY_SINGLE_SUFFIX, - scope, + return buildAs(dimension.getName(), + parse(dimension.getExpr() + Constants.DIMENSION_ARRAY_SINGLE_SUFFIX, scope, engineType)); } throw new Exception("array dimension expr should only identify"); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/IdentifyNode.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/IdentifyNode.java index d64230f90..d42ef4948 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/IdentifyNode.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/IdentifyNode.java @@ -18,10 +18,8 @@ public class IdentifyNode extends SemanticNode { } public static Set getIdentifyNames(List identifies, Identify.Type type) { - return identifies.stream() - .filter(i -> type.name().equalsIgnoreCase(i.getType())) - .map(i -> i.getName()) - .collect(Collectors.toSet()); + return identifies.stream().filter(i -> type.name().equalsIgnoreCase(i.getType())) + .map(i -> i.getName()).collect(Collectors.toSet()); } public static boolean isForeign(String name, List identifies) { diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/MeasureNode.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/MeasureNode.java index 1ba4a5558..41638d097 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/MeasureNode.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/MeasureNode.java @@ -7,29 +7,25 @@ import org.apache.calcite.sql.validate.SqlValidatorScope; public class MeasureNode extends SemanticNode { - public static SqlNode buildNonAgg( - String alias, Measure measure, SqlValidatorScope scope, EngineType engineType) - throws Exception { + public static SqlNode buildNonAgg(String alias, Measure measure, SqlValidatorScope scope, + EngineType engineType) throws Exception { return buildAs(measure.getName(), getExpr(measure, alias, scope, engineType)); } - public static SqlNode buildAgg( - Measure measure, boolean noAgg, SqlValidatorScope scope, EngineType engineType) - throws Exception { + public static SqlNode buildAgg(Measure measure, boolean noAgg, SqlValidatorScope scope, + EngineType engineType) throws Exception { if ((measure.getAgg() == null || measure.getAgg().isEmpty()) || noAgg) { return parse(measure.getName(), scope, engineType); } - return buildAs( - measure.getName(), + return buildAs(measure.getName(), AggFunctionNode.build(measure.getAgg(), measure.getName(), scope, engineType)); } - private static SqlNode getExpr( - Measure measure, String alias, SqlValidatorScope scope, EngineType enginType) - throws Exception { + private static SqlNode getExpr(Measure measure, String alias, SqlValidatorScope scope, + EngineType enginType) throws Exception { if (measure.getExpr() == null) { - return parse( - (alias.isEmpty() ? "" : alias + ".") + measure.getName(), scope, enginType); + return parse((alias.isEmpty() ? "" : alias + ".") + measure.getName(), scope, + enginType); } return parse(measure.getExpr(), scope, enginType); } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/MetricNode.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/MetricNode.java index c6f897343..6a894452b 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/MetricNode.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/MetricNode.java @@ -22,8 +22,7 @@ public class MetricNode extends SemanticNode { public static SqlNode build(Metric metric, SqlValidatorScope scope, EngineType engineType) throws Exception { - if (metric.getMetricTypeParams() == null - || metric.getMetricTypeParams().getExpr() == null + if (metric.getMetricTypeParams() == null || metric.getMetricTypeParams().getExpr() == null || metric.getMetricTypeParams().getExpr().isEmpty()) { return parse(metric.getName(), scope, engineType); } @@ -32,10 +31,8 @@ public class MetricNode extends SemanticNode { } public static Boolean isMetricField(String name, SemanticSchema schema) { - Optional metric = - schema.getMetrics().stream() - .filter(m -> m.getName().equalsIgnoreCase(name)) - .findFirst(); + Optional metric = schema.getMetrics().stream() + .filter(m -> m.getName().equalsIgnoreCase(name)).findFirst(); return metric.isPresent() && metric.get().getMetricTypeParams().isFieldMetric(); } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/SemanticNode.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/SemanticNode.java index 519aa1adc..d2e8826f3 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/SemanticNode.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/node/SemanticNode.java @@ -65,7 +65,7 @@ public abstract class SemanticNode { AGGREGATION_KIND.add(SqlKind.SUM); AGGREGATION_KIND.add(SqlKind.MAX); AGGREGATION_KIND.add(SqlKind.MIN); - AGGREGATION_KIND.add(SqlKind.OTHER_FUNCTION); // more + AGGREGATION_KIND.add(SqlKind.OTHER_FUNCTION); // more AGGREGATION_FUNC.add("sum"); AGGREGATION_FUNC.add("count"); AGGREGATION_FUNC.add("max"); @@ -75,11 +75,9 @@ public abstract class SemanticNode { public static SqlNode parse(String expression, SqlValidatorScope scope, EngineType engineType) throws Exception { - SqlValidatorWithHints sqlValidatorWithHints = - Configuration.getSqlValidatorWithHints( - scope.getValidator().getCatalogReader().getRootSchema(), engineType); - if (Configuration.getSqlAdvisor(sqlValidatorWithHints, engineType) - .getReservedAndKeyWords() + SqlValidatorWithHints sqlValidatorWithHints = Configuration.getSqlValidatorWithHints( + scope.getValidator().getCatalogReader().getRootSchema(), engineType); + if (Configuration.getSqlAdvisor(sqlValidatorWithHints, engineType).getReservedAndKeyWords() .contains(expression.toUpperCase())) { expression = String.format("`%s`", expression); } @@ -93,10 +91,8 @@ public abstract class SemanticNode { public static SqlNode buildAs(String asName, SqlNode sqlNode) throws Exception { SqlAsOperator sqlAsOperator = new SqlAsOperator(); SqlIdentifier sqlIdentifier = new SqlIdentifier(asName, SqlParserPos.ZERO); - return new SqlBasicCall( - sqlAsOperator, - new ArrayList<>(Arrays.asList(sqlNode, sqlIdentifier)), - SqlParserPos.ZERO); + return new SqlBasicCall(sqlAsOperator, + new ArrayList<>(Arrays.asList(sqlNode, sqlIdentifier)), SqlParserPos.ZERO); } public static String getSql(SqlNode sqlNode, EngineType engineType) { @@ -154,17 +150,10 @@ public abstract class SemanticNode { if (table instanceof SqlSelect) { SqlSelect tableSelect = (SqlSelect) table; return tableSelect.getSelectList().stream() - .map( - s -> - (s instanceof SqlIdentifier) - ? ((SqlIdentifier) s).names.get(0) - : (((s instanceof SqlBasicCall) - && s.getKind().equals(SqlKind.AS)) - ? ((SqlBasicCall) s) - .getOperandList() - .get(1) - .toString() - : "")) + .map(s -> (s instanceof SqlIdentifier) ? ((SqlIdentifier) s).names.get(0) + : (((s instanceof SqlBasicCall) && s.getKind().equals(SqlKind.AS)) + ? ((SqlBasicCall) s).getOperandList().get(1).toString() + : "")) .collect(Collectors.toSet()); } return new HashSet<>(); @@ -192,10 +181,8 @@ public abstract class SemanticNode { case AS: SqlBasicCall sqlBasicCall = (SqlBasicCall) sqlNode; if (sqlBasicCall.getOperandList().get(0).getKind().equals(SqlKind.IDENTIFIER)) { - addTableName( - sqlBasicCall.getOperandList().get(0).toString(), - sqlBasicCall.getOperandList().get(1).toString(), - parseInfo); + addTableName(sqlBasicCall.getOperandList().get(0).toString(), + sqlBasicCall.getOperandList().get(1).toString(), parseInfo); } else { sqlVisit(sqlBasicCall.getOperandList().get(0), parseInfo); } @@ -211,12 +198,9 @@ public abstract class SemanticNode { } break; case UNION: - ((SqlBasicCall) sqlNode) - .getOperandList() - .forEach( - node -> { - sqlVisit(node, parseInfo); - }); + ((SqlBasicCall) sqlNode).getOperandList().forEach(node -> { + sqlVisit(node, parseInfo); + }); break; case WITH: SqlWith sqlWith = (SqlWith) sqlNode; @@ -233,12 +217,9 @@ public abstract class SemanticNode { } SqlSelect sqlSelect = (SqlSelect) select; SqlNodeList selectList = sqlSelect.getSelectList(); - selectList - .getList() - .forEach( - list -> { - fieldVisit(list, parseInfo, ""); - }); + selectList.getList().forEach(list -> { + fieldVisit(list, parseInfo, ""); + }); fromVisit(sqlSelect.getFrom(), parseInfo); if (sqlSelect.hasWhere()) { whereVisit((SqlBasicCall) sqlSelect.getWhere(), parseInfo); @@ -248,17 +229,16 @@ public abstract class SemanticNode { } SqlNodeList group = sqlSelect.getGroup(); if (group != null) { - group.forEach( - groupField -> { - if (groupHints.contains(groupField.toString())) { - int groupIdx = Integer.valueOf(groupField.toString()) - 1; - if (selectList.getList().size() > groupIdx) { - fieldVisit(selectList.get(groupIdx), parseInfo, ""); - } - } else { - fieldVisit(groupField, parseInfo, ""); - } - }); + group.forEach(groupField -> { + if (groupHints.contains(groupField.toString())) { + int groupIdx = Integer.valueOf(groupField.toString()) - 1; + if (selectList.getList().size() > groupIdx) { + fieldVisit(selectList.get(groupIdx), parseInfo, ""); + } + } else { + fieldVisit(groupField, parseInfo, ""); + } + }); } } @@ -266,17 +246,15 @@ public abstract class SemanticNode { if (where == null) { return; } - if (where.operandCount() == 2 - && where.operand(0).getKind().equals(SqlKind.IDENTIFIER) + if (where.operandCount() == 2 && where.operand(0).getKind().equals(SqlKind.IDENTIFIER) && where.operand(1).getKind().equals(SqlKind.LITERAL)) { fieldVisit(where.operand(0), parseInfo, ""); return; } // 子查询 - if (where.operandCount() == 2 - && (where.operand(0).getKind().equals(SqlKind.IDENTIFIER) - && (where.operand(1).getKind().equals(SqlKind.SELECT) - || where.operand(1).getKind().equals(SqlKind.ORDER_BY)))) { + if (where.operandCount() == 2 && (where.operand(0).getKind().equals(SqlKind.IDENTIFIER) + && (where.operand(1).getKind().equals(SqlKind.SELECT) + || where.operand(1).getKind().equals(SqlKind.ORDER_BY)))) { fieldVisit(where.operand(0), parseInfo, ""); sqlVisit((SqlNode) (where.operand(1)), parseInfo); return; @@ -331,12 +309,9 @@ public abstract class SemanticNode { } } if (field instanceof SqlNodeList) { - ((SqlNodeList) field) - .getList() - .forEach( - node -> { - fieldVisit(node, parseInfo, ""); - }); + ((SqlNodeList) field).getList().forEach(node -> { + fieldVisit(node, parseInfo, ""); + }); } } @@ -421,10 +396,7 @@ public abstract class SemanticNode { return parseInfo; } - public static SqlNode optimize( - SqlValidatorScope scope, - SemanticSchema schema, - SqlNode sqlNode, + public static SqlNode optimize(SqlValidatorScope scope, SemanticSchema schema, SqlNode sqlNode, EngineType engineType) { try { HepProgramBuilder hepProgramBuilder = new HepProgramBuilder(); @@ -433,16 +405,13 @@ public abstract class SemanticNode { new FilterToGroupScanRule(FilterToGroupScanRule.DEFAULT, schema)); RelOptPlanner relOptPlanner = new HepPlanner(hepProgramBuilder.build()); RelToSqlConverter converter = new RelToSqlConverter(sqlDialect); - SqlValidator sqlValidator = - Configuration.getSqlValidator( - scope.getValidator().getCatalogReader().getRootSchema(), engineType); - SqlToRelConverter sqlToRelConverter = - Configuration.getSqlToRelConverter( - scope, sqlValidator, relOptPlanner, engineType); + SqlValidator sqlValidator = Configuration.getSqlValidator( + scope.getValidator().getCatalogReader().getRootSchema(), engineType); + SqlToRelConverter sqlToRelConverter = Configuration.getSqlToRelConverter(scope, + sqlValidator, relOptPlanner, engineType); RelNode sqlRel = sqlToRelConverter.convertQuery(sqlValidator.validate(sqlNode), false, true).rel; - log.debug( - "RelNode optimize {}", + log.debug("RelNode optimize {}", SemanticNode.getSql(converter.visitRoot(sqlRel).asStatement(), engineType)); relOptPlanner.setRoot(sqlRel); RelNode relNode = relOptPlanner.findBestExp(); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/optimizer/FilterToGroupScanRule.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/optimizer/FilterToGroupScanRule.java index ed77131a8..b812f8f8c 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/optimizer/FilterToGroupScanRule.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/optimizer/FilterToGroupScanRule.java @@ -30,29 +30,15 @@ import java.util.Optional; public class FilterToGroupScanRule extends RelRule implements TransformationRule { public static FilterTableScanRule.Config DEFAULT = - FilterTableScanRule.Config.DEFAULT - .withOperandSupplier( - (b0) -> { - return b0.operand(LogicalFilter.class) - .oneInput( - (b1) -> { - return b1.operand(LogicalProject.class) - .oneInput( - (b2) -> { - return b2.operand( - LogicalAggregate - .class) - .oneInput( - (b3) -> { - return b3.operand( - LogicalProject - .class) - .anyInputs(); - }); - }); - }); - }) - .as(FilterTableScanRule.Config.class); + FilterTableScanRule.Config.DEFAULT.withOperandSupplier((b0) -> { + return b0.operand(LogicalFilter.class).oneInput((b1) -> { + return b1.operand(LogicalProject.class).oneInput((b2) -> { + return b2.operand(LogicalAggregate.class).oneInput((b3) -> { + return b3.operand(LogicalProject.class).anyInputs(); + }); + }); + }); + }).as(FilterTableScanRule.Config.class); private SemanticSchema semanticSchema; @@ -75,19 +61,16 @@ public class FilterToGroupScanRule extends RelRule implements Transforma Project project0 = (Project) call.rel(1); Project project1 = (Project) call.rel(3); Aggregate logicalAggregate = (Aggregate) call.rel(2); - Optional> isIn = - project1.getNamedProjects().stream() - .filter(i -> i.right.equalsIgnoreCase(minMax.getLeft())) - .findFirst(); + Optional> isIn = project1.getNamedProjects().stream() + .filter(i -> i.right.equalsIgnoreCase(minMax.getLeft())).findFirst(); if (!isIn.isPresent()) { return; } RelBuilder relBuilder = call.builder(); relBuilder.push(project1); - RexNode addPartitionCondition = - getRexNodeByTimeRange( - relBuilder, minMax.getLeft(), minMax.getMiddle(), minMax.getRight()); + RexNode addPartitionCondition = getRexNodeByTimeRange(relBuilder, minMax.getLeft(), + minMax.getMiddle(), minMax.getRight()); relBuilder.filter(new RexNode[] {addPartitionCondition}); relBuilder.project(project1.getProjects()); ImmutableBitSet newGroupSet = logicalAggregate.getGroupSet(); @@ -97,13 +80,8 @@ public class FilterToGroupScanRule extends RelRule implements Transforma Iterator var = logicalAggregate.getAggCallList().iterator(); while (var.hasNext()) { AggregateCall aggCall = (AggregateCall) var.next(); - newAggCalls.add( - aggCall.adaptTo( - project1, - aggCall.getArgList(), - aggCall.filterArg, - groupCount, - newGroupCount)); + newAggCalls.add(aggCall.adaptTo(project1, aggCall.getArgList(), aggCall.filterArg, + groupCount, newGroupCount)); } relBuilder.aggregate(relBuilder.groupKey(newGroupSet), newAggCalls); relBuilder.project(project0.getProjects()); @@ -111,17 +89,12 @@ public class FilterToGroupScanRule extends RelRule implements Transforma call.transformTo(relBuilder.build()); } - private RexNode getRexNodeByTimeRange( - RelBuilder relBuilder, String dateField, String start, String end) { - return relBuilder.call( - SqlStdOperatorTable.AND, - relBuilder.call( - SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, - relBuilder.field(dateField), - relBuilder.literal(start)), - relBuilder.call( - SqlStdOperatorTable.LESS_THAN_OR_EQUAL, - relBuilder.field(dateField), + private RexNode getRexNodeByTimeRange(RelBuilder relBuilder, String dateField, String start, + String end) { + return relBuilder.call(SqlStdOperatorTable.AND, + relBuilder.call(SqlStdOperatorTable.GREATER_THAN_OR_EQUAL, + relBuilder.field(dateField), relBuilder.literal(start)), + relBuilder.call(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, relBuilder.field(dateField), relBuilder.literal(end))); } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/render/FilterRender.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/render/FilterRender.java index 9372000d5..1c815e12c 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/render/FilterRender.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/render/FilterRender.java @@ -27,13 +27,8 @@ import java.util.stream.Collectors; public class FilterRender extends Renderer { @Override - public void render( - MetricQueryParam metricCommand, - List dataSources, - SqlValidatorScope scope, - SemanticSchema schema, - boolean nonAgg) - throws Exception { + public void render(MetricQueryParam metricCommand, List dataSources, + SqlValidatorScope scope, SemanticSchema schema, boolean nonAgg) throws Exception { TableView tableView = super.tableView; SqlNode filterNode = null; List queryMetrics = new ArrayList<>(metricCommand.getMetrics()); @@ -49,14 +44,8 @@ public class FilterRender extends Renderer { Set dimensions = new HashSet<>(); Set metrics = new HashSet<>(); for (DataSource dataSource : dataSources) { - SourceRender.whereDimMetric( - fieldWhere, - metricCommand.getMetrics(), - metricCommand.getDimensions(), - dataSource, - schema, - dimensions, - metrics); + SourceRender.whereDimMetric(fieldWhere, metricCommand.getMetrics(), + metricCommand.getDimensions(), dataSource, schema, dimensions, metrics); } queryMetrics.addAll(metrics); queryDimensions.addAll(dimensions); @@ -71,8 +60,7 @@ public class FilterRender extends Renderer { continue; } if (optionalMetric.isPresent()) { - tableView - .getMeasure() + tableView.getMeasure() .add(MetricNode.build(optionalMetric.get(), scope, engineType)); } else { tableView.getMeasure().add(SemanticNode.parse(metric, scope, engineType)); @@ -80,9 +68,8 @@ public class FilterRender extends Renderer { } if (filterNode != null) { TableView filterView = new TableView(); - filterView.setTable( - SemanticNode.buildAs( - Constants.DATASOURCE_TABLE_FILTER_PREFIX, tableView.build())); + filterView.setTable(SemanticNode.buildAs(Constants.DATASOURCE_TABLE_FILTER_PREFIX, + tableView.build())); filterView.getFilter().add(filterNode); filterView.getMeasure().add(SqlIdentifier.star(SqlParserPos.ZERO)); super.tableView = filterView; diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/render/JoinRender.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/render/JoinRender.java index 20bf97819..6bc729efe 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/render/JoinRender.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/render/JoinRender.java @@ -48,13 +48,8 @@ import java.util.stream.Collectors; public class JoinRender extends Renderer { @Override - public void render( - MetricQueryParam metricCommand, - List dataSources, - SqlValidatorScope scope, - SemanticSchema schema, - boolean nonAgg) - throws Exception { + public void render(MetricQueryParam metricCommand, List dataSources, + SqlValidatorScope scope, SemanticSchema schema, boolean nonAgg) throws Exception { String queryWhere = metricCommand.getWhere(); EngineType engineType = EngineType.fromString(schema.getSemanticModel().getDatabase().getType()); @@ -82,14 +77,8 @@ public class JoinRender extends Renderer { final Set filterMetrics = new HashSet<>(); final List queryDimension = new ArrayList<>(); final List queryMetrics = new ArrayList<>(); - SourceRender.whereDimMetric( - fieldWhere, - queryMetrics, - queryDimension, - dataSource, - schema, - filterDimensions, - filterMetrics); + SourceRender.whereDimMetric(fieldWhere, queryMetrics, queryDimension, dataSource, + schema, filterDimensions, filterMetrics); List reqMetric = new ArrayList<>(metricCommand.getMetrics()); reqMetric.addAll(filterMetrics); reqMetric = uniqList(reqMetric); @@ -98,33 +87,14 @@ public class JoinRender extends Renderer { reqDimension.addAll(filterDimensions); reqDimension = uniqList(reqDimension); - Set sourceMeasure = - dataSource.getMeasures().stream() - .map(mm -> mm.getName()) - .collect(Collectors.toSet()); - doMetric( - innerSelect, - filterView, - queryMetrics, - reqMetric, - dataSource, - sourceMeasure, - scope, - schema, - nonAgg); - Set dimension = - dataSource.getDimensions().stream() - .map(dd -> dd.getName()) - .collect(Collectors.toSet()); - doDimension( - innerSelect, - filterDimension, - queryDimension, - reqDimension, - dataSource, - dimension, - scope, - schema); + Set sourceMeasure = dataSource.getMeasures().stream().map(mm -> mm.getName()) + .collect(Collectors.toSet()); + doMetric(innerSelect, filterView, queryMetrics, reqMetric, dataSource, sourceMeasure, + scope, schema, nonAgg); + Set dimension = dataSource.getDimensions().stream().map(dd -> dd.getName()) + .collect(Collectors.toSet()); + doDimension(innerSelect, filterDimension, queryDimension, reqDimension, dataSource, + dimension, scope, schema); List primary = new ArrayList<>(); for (Identify identify : dataSource.getIdentifiers()) { primary.add(identify.getName()); @@ -135,16 +105,8 @@ public class JoinRender extends Renderer { List dataSourceWhere = new ArrayList<>(fieldWhere); addZipperField(dataSource, dataSourceWhere); TableView tableView = - SourceRender.renderOne( - "", - dataSourceWhere, - queryMetrics, - queryDimension, - metricCommand.getWhere(), - dataSources.get(i), - scope, - schema, - true); + SourceRender.renderOne("", dataSourceWhere, queryMetrics, queryDimension, + metricCommand.getWhere(), dataSources.get(i), scope, schema, true); log.info("tableView {}", StringUtils.normalizeSpace(tableView.getTable().toString())); String alias = Constants.JOIN_TABLE_PREFIX + dataSource.getName(); tableView.setAlias(alias); @@ -165,8 +127,8 @@ public class JoinRender extends Renderer { innerView.getMeasure().add(entry.getValue()); } innerView.setTable(left); - filterView.setTable( - SemanticNode.buildAs(Constants.JOIN_TABLE_OUT_PREFIX, innerView.build())); + filterView + .setTable(SemanticNode.buildAs(Constants.JOIN_TABLE_OUT_PREFIX, innerView.build())); if (!filterDimension.isEmpty()) { for (String d : getQueryDimension(filterDimension, queryAllDimension, whereFields)) { if (nonAgg) { @@ -179,17 +141,10 @@ public class JoinRender extends Renderer { super.tableView = filterView; } - private void doMetric( - Map innerSelect, - TableView filterView, - List queryMetrics, - List reqMetrics, - DataSource dataSource, - Set sourceMeasure, - SqlValidatorScope scope, - SemanticSchema schema, - boolean nonAgg) - throws Exception { + private void doMetric(Map innerSelect, TableView filterView, + List queryMetrics, List reqMetrics, DataSource dataSource, + Set sourceMeasure, SqlValidatorScope scope, SemanticSchema schema, + boolean nonAgg) throws Exception { String alias = Constants.JOIN_TABLE_PREFIX + dataSource.getName(); EngineType engineType = EngineType.fromString(schema.getSemanticModel().getDatabase().getType()); @@ -200,38 +155,21 @@ public class JoinRender extends Renderer { if (!metricNode.getNonAggNode().isEmpty()) { for (String measure : metricNode.getNonAggNode().keySet()) { - innerSelect.put( - measure, - SemanticNode.buildAs( - measure, - SemanticNode.parse( - alias + "." + measure, scope, engineType))); + innerSelect.put(measure, SemanticNode.buildAs(measure, + SemanticNode.parse(alias + "." + measure, scope, engineType))); } } if (metricNode.getAggFunction() != null && !metricNode.getAggFunction().isEmpty()) { for (Map.Entry entry : metricNode.getAggFunction().entrySet()) { if (metricNode.getNonAggNode().containsKey(entry.getKey())) { if (nonAgg) { - filterView - .getMeasure() - .add( - SemanticNode.buildAs( - entry.getKey(), - SemanticNode.parse( - entry.getKey(), - scope, - engineType))); + filterView.getMeasure().add(SemanticNode.buildAs(entry.getKey(), + SemanticNode.parse(entry.getKey(), scope, engineType))); } else { - filterView - .getMeasure() - .add( - SemanticNode.buildAs( - entry.getKey(), - AggFunctionNode.build( - entry.getValue(), - entry.getKey(), - scope, - engineType))); + filterView.getMeasure() + .add(SemanticNode.buildAs(entry.getKey(), + AggFunctionNode.build(entry.getValue(), + entry.getKey(), scope, engineType))); } } } @@ -240,15 +178,9 @@ public class JoinRender extends Renderer { } } - private void doDimension( - Map innerSelect, - Set filterDimension, - List queryDimension, - List reqDimensions, - DataSource dataSource, - Set dimension, - SqlValidatorScope scope, - SemanticSchema schema) + private void doDimension(Map innerSelect, Set filterDimension, + List queryDimension, List reqDimensions, DataSource dataSource, + Set dimension, SqlValidatorScope scope, SemanticSchema schema) throws Exception { String alias = Constants.JOIN_TABLE_PREFIX + dataSource.getName(); EngineType engineType = @@ -257,44 +189,32 @@ public class JoinRender extends Renderer { if (getMatchDimension(schema, dimension, dataSource, d, queryDimension)) { if (d.contains(Constants.DIMENSION_IDENTIFY)) { String[] identifyDimension = d.split(Constants.DIMENSION_IDENTIFY); - innerSelect.put( - d, - SemanticNode.buildAs( - d, - SemanticNode.parse( - alias + "." + identifyDimension[1], - scope, - engineType))); + innerSelect.put(d, SemanticNode.buildAs(d, SemanticNode + .parse(alias + "." + identifyDimension[1], scope, engineType))); } else { - innerSelect.put( - d, - SemanticNode.buildAs( - d, SemanticNode.parse(alias + "." + d, scope, engineType))); + innerSelect.put(d, SemanticNode.buildAs(d, + SemanticNode.parse(alias + "." + d, scope, engineType))); } filterDimension.add(d); } } } - private Set getQueryDimension( - Set filterDimension, Set queryAllDimension, Set whereFields) { + private Set getQueryDimension(Set filterDimension, + Set queryAllDimension, Set whereFields) { return filterDimension.stream() .filter(d -> queryAllDimension.contains(d) || whereFields.contains(d)) .collect(Collectors.toSet()); } - private boolean getMatchMetric( - SemanticSchema schema, Set sourceMeasure, String m, List queryMetrics) { - Optional metric = - schema.getMetrics().stream() - .filter(mm -> mm.getName().equalsIgnoreCase(m)) - .findFirst(); + private boolean getMatchMetric(SemanticSchema schema, Set sourceMeasure, String m, + List queryMetrics) { + Optional metric = schema.getMetrics().stream() + .filter(mm -> mm.getName().equalsIgnoreCase(m)).findFirst(); boolean isAdd = false; if (metric.isPresent()) { - Set metricMeasures = - metric.get().getMetricTypeParams().getMeasures().stream() - .map(me -> me.getName()) - .collect(Collectors.toSet()); + Set metricMeasures = metric.get().getMetricTypeParams().getMeasures().stream() + .map(me -> me.getName()).collect(Collectors.toSet()); if (sourceMeasure.containsAll(metricMeasures)) { isAdd = true; } @@ -308,12 +228,8 @@ public class JoinRender extends Renderer { return isAdd; } - private boolean getMatchDimension( - SemanticSchema schema, - Set sourceDimension, - DataSource dataSource, - String d, - List queryDimension) { + private boolean getMatchDimension(SemanticSchema schema, Set sourceDimension, + DataSource dataSource, String d, List queryDimension) { String oriDimension = d; boolean isAdd = false; if (d.contains(Constants.DIMENSION_IDENTIFY)) { @@ -345,15 +261,9 @@ public class JoinRender extends Renderer { return SemanticNode.getTable(tableView.getTable()); } - private SqlNode buildJoin( - SqlNode left, - TableView leftTable, - TableView tableView, - Map before, - DataSource dataSource, - SemanticSchema schema, - SqlValidatorScope scope) - throws Exception { + private SqlNode buildJoin(SqlNode left, TableView leftTable, TableView tableView, + Map before, DataSource dataSource, SemanticSchema schema, + SqlValidatorScope scope) throws Exception { EngineType engineType = EngineType.fromString(schema.getSemanticModel().getDatabase().getType()); SqlNode condition = @@ -367,53 +277,37 @@ public class JoinRender extends Renderer { condition = joinRelationCondition; } if (Materialization.TimePartType.ZIPPER.equals(leftTable.getDataSource().getTimePartType()) - || Materialization.TimePartType.ZIPPER.equals( - tableView.getDataSource().getTimePartType())) { + || Materialization.TimePartType.ZIPPER + .equals(tableView.getDataSource().getTimePartType())) { SqlNode zipperCondition = getZipperCondition(leftTable, tableView, dataSource, schema, scope); if (Objects.nonNull(joinRelationCondition)) { - condition = - new SqlBasicCall( - SqlStdOperatorTable.AND, - new ArrayList<>( - Arrays.asList(zipperCondition, joinRelationCondition)), - SqlParserPos.ZERO, - null); + condition = new SqlBasicCall(SqlStdOperatorTable.AND, + new ArrayList<>(Arrays.asList(zipperCondition, joinRelationCondition)), + SqlParserPos.ZERO, null); } else { condition = zipperCondition; } } - return new SqlJoin( - SqlParserPos.ZERO, - left, - SqlLiteral.createBoolean(false, SqlParserPos.ZERO), - sqlLiteral, + return new SqlJoin(SqlParserPos.ZERO, left, + SqlLiteral.createBoolean(false, SqlParserPos.ZERO), sqlLiteral, SemanticNode.buildAs(tableView.getAlias(), getTable(tableView, scope)), - SqlLiteral.createSymbol(JoinConditionType.ON, SqlParserPos.ZERO), - condition); + SqlLiteral.createSymbol(JoinConditionType.ON, SqlParserPos.ZERO), condition); } - private JoinRelation getMatchJoinRelation( - Map before, TableView tableView, SemanticSchema schema) { + private JoinRelation getMatchJoinRelation(Map before, TableView tableView, + SemanticSchema schema) { JoinRelation matchJoinRelation = JoinRelation.builder().build(); if (!CollectionUtils.isEmpty(schema.getJoinRelations())) { for (JoinRelation joinRelation : schema.getJoinRelations()) { if (joinRelation.getRight().equalsIgnoreCase(tableView.getDataSource().getName()) && before.containsKey(joinRelation.getLeft())) { - matchJoinRelation.setJoinCondition( - joinRelation.getJoinCondition().stream() - .map( - r -> - Triple.of( - before.get(joinRelation.getLeft()) - + "." - + r.getLeft(), - r.getMiddle(), - tableView.getAlias() - + "." - + r.getRight())) - .collect(Collectors.toList())); + matchJoinRelation.setJoinCondition(joinRelation.getJoinCondition().stream() + .map(r -> Triple.of( + before.get(joinRelation.getLeft()) + "." + r.getLeft(), + r.getMiddle(), tableView.getAlias() + "." + r.getRight())) + .collect(Collectors.toList())); matchJoinRelation.setJoinType(joinRelation.getJoinType()); } } @@ -421,46 +315,29 @@ public class JoinRender extends Renderer { return matchJoinRelation; } - private SqlNode getCondition( - JoinRelation joinRelation, SqlValidatorScope scope, EngineType engineType) - throws Exception { + private SqlNode getCondition(JoinRelation joinRelation, SqlValidatorScope scope, + EngineType engineType) throws Exception { SqlNode condition = null; for (Triple con : joinRelation.getJoinCondition()) { List ons = new ArrayList<>(); ons.add(SemanticNode.parse(con.getLeft(), scope, engineType)); ons.add(SemanticNode.parse(con.getRight(), scope, engineType)); if (Objects.isNull(condition)) { - condition = - new SqlBasicCall( - SemanticNode.getBinaryOperator(con.getMiddle()), - ons, - SqlParserPos.ZERO, - null); + condition = new SqlBasicCall(SemanticNode.getBinaryOperator(con.getMiddle()), ons, + SqlParserPos.ZERO, null); continue; } - SqlNode addCondition = - new SqlBasicCall( - SemanticNode.getBinaryOperator(con.getMiddle()), - ons, - SqlParserPos.ZERO, - null); - condition = - new SqlBasicCall( - SqlStdOperatorTable.AND, - new ArrayList<>(Arrays.asList(condition, addCondition)), - SqlParserPos.ZERO, - null); + SqlNode addCondition = new SqlBasicCall(SemanticNode.getBinaryOperator(con.getMiddle()), + ons, SqlParserPos.ZERO, null); + condition = new SqlBasicCall(SqlStdOperatorTable.AND, + new ArrayList<>(Arrays.asList(condition, addCondition)), SqlParserPos.ZERO, + null); } return condition; } - private SqlNode getCondition( - TableView left, - TableView right, - DataSource dataSource, - SemanticSchema schema, - SqlValidatorScope scope, - EngineType engineType) + private SqlNode getCondition(TableView left, TableView right, DataSource dataSource, + SemanticSchema schema, SqlValidatorScope scope, EngineType engineType) throws Exception { Set selectLeft = SemanticNode.getSelect(left.getTable()); @@ -491,22 +368,15 @@ public class JoinRender extends Renderer { } SqlNode addCondition = new SqlBasicCall(SqlStdOperatorTable.EQUALS, ons, SqlParserPos.ZERO, null); - condition = - new SqlBasicCall( - SqlStdOperatorTable.AND, - new ArrayList<>(Arrays.asList(condition, addCondition)), - SqlParserPos.ZERO, - null); + condition = new SqlBasicCall(SqlStdOperatorTable.AND, + new ArrayList<>(Arrays.asList(condition, addCondition)), SqlParserPos.ZERO, + null); } return condition; } - private static void joinOrder( - int cnt, - String id, - Map> next, - Queue orders, - Map visited) { + private static void joinOrder(int cnt, String id, Map> next, + Queue orders, Map visited) { visited.put(id, true); orders.add(id); if (orders.size() >= cnt) { @@ -528,97 +398,62 @@ public class JoinRender extends Renderer { if (Materialization.TimePartType.ZIPPER.equals(dataSource.getTimePartType())) { dataSource.getDimensions().stream() .filter(d -> Constants.DIMENSION_TYPE_TIME.equalsIgnoreCase(d.getType())) - .forEach( - t -> { - if (t.getName().startsWith(Constants.MATERIALIZATION_ZIPPER_END) - && !fields.contains(t.getName())) { - fields.add(t.getName()); - } - if (t.getName().startsWith(Constants.MATERIALIZATION_ZIPPER_START) - && !fields.contains(t.getName())) { - fields.add(t.getName()); - } - }); + .forEach(t -> { + if (t.getName().startsWith(Constants.MATERIALIZATION_ZIPPER_END) + && !fields.contains(t.getName())) { + fields.add(t.getName()); + } + if (t.getName().startsWith(Constants.MATERIALIZATION_ZIPPER_START) + && !fields.contains(t.getName())) { + fields.add(t.getName()); + } + }); } } - private SqlNode getZipperCondition( - TableView left, - TableView right, - DataSource dataSource, - SemanticSchema schema, - SqlValidatorScope scope) - throws Exception { + private SqlNode getZipperCondition(TableView left, TableView right, DataSource dataSource, + SemanticSchema schema, SqlValidatorScope scope) throws Exception { if (Materialization.TimePartType.ZIPPER.equals(left.getDataSource().getTimePartType()) - && Materialization.TimePartType.ZIPPER.equals( - right.getDataSource().getTimePartType())) { + && Materialization.TimePartType.ZIPPER + .equals(right.getDataSource().getTimePartType())) { throw new Exception("not support two zipper table"); } SqlNode condition = null; - Optional leftTime = - left.getDataSource().getDimensions().stream() - .filter(d -> Constants.DIMENSION_TYPE_TIME.equalsIgnoreCase(d.getType())) - .findFirst(); - Optional rightTime = - right.getDataSource().getDimensions().stream() - .filter(d -> Constants.DIMENSION_TYPE_TIME.equalsIgnoreCase(d.getType())) - .findFirst(); + Optional leftTime = left.getDataSource().getDimensions().stream() + .filter(d -> Constants.DIMENSION_TYPE_TIME.equalsIgnoreCase(d.getType())) + .findFirst(); + Optional rightTime = right.getDataSource().getDimensions().stream() + .filter(d -> Constants.DIMENSION_TYPE_TIME.equalsIgnoreCase(d.getType())) + .findFirst(); if (leftTime.isPresent() && rightTime.isPresent()) { String startTime = ""; String endTime = ""; String dateTime = ""; - Optional startTimeOp = - (Materialization.TimePartType.ZIPPER.equals( - left.getDataSource().getTimePartType()) - ? left - : right) - .getDataSource().getDimensions().stream() - .filter( - d -> - Constants.DIMENSION_TYPE_TIME.equalsIgnoreCase( - d.getType())) - .filter( - d -> - d.getName() - .startsWith( - Constants - .MATERIALIZATION_ZIPPER_START)) - .findFirst(); - Optional endTimeOp = - (Materialization.TimePartType.ZIPPER.equals( - left.getDataSource().getTimePartType()) - ? left - : right) - .getDataSource().getDimensions().stream() - .filter( - d -> - Constants.DIMENSION_TYPE_TIME.equalsIgnoreCase( - d.getType())) - .filter( - d -> - d.getName() - .startsWith( - Constants - .MATERIALIZATION_ZIPPER_END)) - .findFirst(); + Optional startTimeOp = (Materialization.TimePartType.ZIPPER + .equals(left.getDataSource().getTimePartType()) ? left : right).getDataSource() + .getDimensions().stream() + .filter(d -> Constants.DIMENSION_TYPE_TIME + .equalsIgnoreCase(d.getType())) + .filter(d -> d.getName() + .startsWith(Constants.MATERIALIZATION_ZIPPER_START)) + .findFirst(); + Optional endTimeOp = (Materialization.TimePartType.ZIPPER + .equals(left.getDataSource().getTimePartType()) ? left : right).getDataSource() + .getDimensions().stream() + .filter(d -> Constants.DIMENSION_TYPE_TIME + .equalsIgnoreCase(d.getType())) + .filter(d -> d.getName() + .startsWith(Constants.MATERIALIZATION_ZIPPER_END)) + .findFirst(); if (startTimeOp.isPresent() && endTimeOp.isPresent()) { - TableView zipper = - Materialization.TimePartType.ZIPPER.equals( - left.getDataSource().getTimePartType()) - ? left - : right; - TableView partMetric = - Materialization.TimePartType.ZIPPER.equals( - left.getDataSource().getTimePartType()) - ? right - : left; - Optional partTime = - Materialization.TimePartType.ZIPPER.equals( - left.getDataSource().getTimePartType()) - ? rightTime - : leftTime; + TableView zipper = Materialization.TimePartType.ZIPPER + .equals(left.getDataSource().getTimePartType()) ? left : right; + TableView partMetric = Materialization.TimePartType.ZIPPER + .equals(left.getDataSource().getTimePartType()) ? right : left; + Optional partTime = Materialization.TimePartType.ZIPPER + .equals(left.getDataSource().getTimePartType()) ? rightTime : leftTime; startTime = zipper.getAlias() + "." + startTimeOp.get().getName(); endTime = zipper.getAlias() + "." + endTimeOp.get().getName(); dateTime = partMetric.getAlias() + "." + partTime.get().getName(); @@ -626,36 +461,18 @@ public class JoinRender extends Renderer { EngineType engineType = EngineType.fromString(schema.getSemanticModel().getDatabase().getType()); ArrayList operandList = - new ArrayList<>( - Arrays.asList( - SemanticNode.parse(endTime, scope, engineType), - SemanticNode.parse(dateTime, scope, engineType))); - condition = - new SqlBasicCall( - SqlStdOperatorTable.AND, - new ArrayList( - Arrays.asList( - new SqlBasicCall( - SqlStdOperatorTable.LESS_THAN_OR_EQUAL, - new ArrayList( - Arrays.asList( - SemanticNode.parse( - startTime, - scope, - engineType), - SemanticNode.parse( - dateTime, - scope, - engineType))), - SqlParserPos.ZERO, - null), - new SqlBasicCall( - SqlStdOperatorTable.GREATER_THAN, - operandList, - SqlParserPos.ZERO, - null))), - SqlParserPos.ZERO, - null); + new ArrayList<>(Arrays.asList(SemanticNode.parse(endTime, scope, engineType), + SemanticNode.parse(dateTime, scope, engineType))); + condition = new SqlBasicCall(SqlStdOperatorTable.AND, + new ArrayList(Arrays.asList( + new SqlBasicCall(SqlStdOperatorTable.LESS_THAN_OR_EQUAL, + new ArrayList(Arrays.asList( + SemanticNode.parse(startTime, scope, engineType), + SemanticNode.parse(dateTime, scope, engineType))), + SqlParserPos.ZERO, null), + new SqlBasicCall(SqlStdOperatorTable.GREATER_THAN, operandList, + SqlParserPos.ZERO, null))), + SqlParserPos.ZERO, null); } return condition; } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/render/OutputRender.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/render/OutputRender.java index fb34d5839..92dffeb9d 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/render/OutputRender.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/render/OutputRender.java @@ -23,13 +23,8 @@ import java.util.List; public class OutputRender extends Renderer { @Override - public void render( - MetricQueryParam metricCommand, - List dataSources, - SqlValidatorScope scope, - SemanticSchema schema, - boolean nonAgg) - throws Exception { + public void render(MetricQueryParam metricCommand, List dataSources, + SqlValidatorScope scope, SemanticSchema schema, boolean nonAgg) throws Exception { TableView selectDataSet = super.tableView; EngineType engineType = EngineType.fromString(schema.getSemanticModel().getDatabase().getType()); @@ -53,12 +48,9 @@ public class OutputRender extends Renderer { List orderList = new ArrayList<>(); for (ColumnOrder columnOrder : metricCommand.getOrder()) { if (SqlStdOperatorTable.DESC.getName().equalsIgnoreCase(columnOrder.getOrder())) { - orderList.add( - SqlStdOperatorTable.DESC.createCall( - SqlParserPos.ZERO, - new SqlNode[] { - SemanticNode.parse(columnOrder.getCol(), scope, engineType) - })); + orderList.add(SqlStdOperatorTable.DESC.createCall(SqlParserPos.ZERO, + new SqlNode[] {SemanticNode.parse(columnOrder.getCol(), scope, + engineType)})); } else { orderList.add(SemanticNode.parse(columnOrder.getCol(), scope, engineType)); } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/render/SourceRender.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/render/SourceRender.java index c73a394f6..29c990525 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/render/SourceRender.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/calcite/sql/render/SourceRender.java @@ -42,16 +42,9 @@ import static com.tencent.supersonic.headless.core.translator.calcite.s2sql.Cons @Slf4j public class SourceRender extends Renderer { - public static TableView renderOne( - String alias, - List fieldWheres, - List reqMetrics, - List reqDimensions, - String queryWhere, - DataSource datasource, - SqlValidatorScope scope, - SemanticSchema schema, - boolean nonAgg) + public static TableView renderOne(String alias, List fieldWheres, + List reqMetrics, List reqDimensions, String queryWhere, + DataSource datasource, SqlValidatorScope scope, SemanticSchema schema, boolean nonAgg) throws Exception { TableView dataSet = new TableView(); @@ -63,29 +56,14 @@ public class SourceRender extends Renderer { if (!fieldWhere.isEmpty()) { Set dimensions = new HashSet<>(); Set metrics = new HashSet<>(); - whereDimMetric( - fieldWhere, - queryMetrics, - queryDimensions, - datasource, - schema, - dimensions, - metrics); + whereDimMetric(fieldWhere, queryMetrics, queryDimensions, datasource, schema, + dimensions, metrics); queryMetrics.addAll(metrics); queryMetrics = uniqList(queryMetrics); queryDimensions.addAll(dimensions); queryDimensions = uniqList(queryDimensions); - mergeWhere( - fieldWhere, - dataSet, - output, - queryMetrics, - queryDimensions, - extendFields, - datasource, - scope, - schema, - nonAgg); + mergeWhere(fieldWhere, dataSet, output, queryMetrics, queryDimensions, extendFields, + datasource, scope, schema, nonAgg); } addTimeDimension(datasource, queryDimensions); for (String metric : queryMetrics) { @@ -109,18 +87,11 @@ public class SourceRender extends Renderer { && queryDimensions.contains(dimension.split(Constants.DIMENSION_IDENTIFY)[1])) { continue; } - buildDimension( - dimension.contains(Constants.DIMENSION_IDENTIFY) ? dimension : "", + buildDimension(dimension.contains(Constants.DIMENSION_IDENTIFY) ? dimension : "", dimension.contains(Constants.DIMENSION_IDENTIFY) ? dimension.split(Constants.DIMENSION_IDENTIFY)[1] : dimension, - datasource, - schema, - nonAgg, - extendFields, - dataSet, - output, - scope); + datasource, schema, nonAgg, extendFields, dataSet, output, scope); } output.setMeasure(deduplicateNode(output.getMeasure())); @@ -129,12 +100,8 @@ public class SourceRender extends Renderer { SqlNode tableNode = DataSourceNode.buildExtend(datasource, extendFields, scope); dataSet.setTable(tableNode); output.setTable( - SemanticNode.buildAs( - Constants.DATASOURCE_TABLE_OUT_PREFIX - + datasource.getName() - + "_" - + UUID.randomUUID().toString().substring(32), - dataSet.build())); + SemanticNode.buildAs(Constants.DATASOURCE_TABLE_OUT_PREFIX + datasource.getName() + + "_" + UUID.randomUUID().toString().substring(32), dataSet.build())); return output; } @@ -148,8 +115,7 @@ public class SourceRender extends Renderer { return uniqueElements; } - private static boolean containsElement( - List list, SqlNode element) { // 检查List中是否含有某element + private static boolean containsElement(List list, SqlNode element) { // 检查List中是否含有某element for (SqlNode i : list) { if (i.equalsDeep(element, Litmus.IGNORE)) { return true; @@ -158,17 +124,9 @@ public class SourceRender extends Renderer { return false; } - private static void buildDimension( - String alias, - String dimension, - DataSource datasource, - SemanticSchema schema, - boolean nonAgg, - Map extendFields, - TableView dataSet, - TableView output, - SqlValidatorScope scope) - throws Exception { + private static void buildDimension(String alias, String dimension, DataSource datasource, + SemanticSchema schema, boolean nonAgg, Map extendFields, + TableView dataSet, TableView output, SqlValidatorScope scope) throws Exception { List dimensionList = schema.getDimension().get(datasource.getName()); EngineType engineType = EngineType.fromString(schema.getSemanticModel().getDatabase().getType()); @@ -197,10 +155,8 @@ public class SourceRender extends Renderer { } } if (!isAdd) { - Optional identify = - datasource.getIdentifiers().stream() - .filter(i -> i.getName().equalsIgnoreCase(dimension)) - .findFirst(); + Optional identify = datasource.getIdentifiers().stream() + .filter(i -> i.getName().equalsIgnoreCase(dimension)).findFirst(); if (identify.isPresent()) { if (nonAgg) { dataSet.getMeasure() @@ -238,24 +194,17 @@ public class SourceRender extends Renderer { if (dimension.getDataType().isArray()) { if (Objects.nonNull(dimension.getExt()) && dimension.getExt().containsKey(DIMENSION_DELIMITER)) { - extendFields.put( - dimension.getExpr(), (String) dimension.getExt().get(DIMENSION_DELIMITER)); + extendFields.put(dimension.getExpr(), + (String) dimension.getExt().get(DIMENSION_DELIMITER)); } else { extendFields.put(dimension.getExpr(), ""); } } } - private static List getWhereMeasure( - List fields, - List queryMetrics, - List queryDimensions, - Map extendFields, - DataSource datasource, - SqlValidatorScope scope, - SemanticSchema schema, - boolean nonAgg) - throws Exception { + private static List getWhereMeasure(List fields, List queryMetrics, + List queryDimensions, Map extendFields, DataSource datasource, + SqlValidatorScope scope, SemanticSchema schema, boolean nonAgg) throws Exception { Iterator iterator = fields.iterator(); List whereNode = new ArrayList<>(); EngineType engineType = @@ -295,40 +244,19 @@ public class SourceRender extends Renderer { return whereNode; } - private static void mergeWhere( - List fields, - TableView dataSet, - TableView outputSet, - List queryMetrics, - List queryDimensions, - Map extendFields, - DataSource datasource, - SqlValidatorScope scope, - SemanticSchema schema, - boolean nonAgg) - throws Exception { - List whereNode = - getWhereMeasure( - fields, - queryMetrics, - queryDimensions, - extendFields, - datasource, - scope, - schema, - nonAgg); + private static void mergeWhere(List fields, TableView dataSet, TableView outputSet, + List queryMetrics, List queryDimensions, + Map extendFields, DataSource datasource, SqlValidatorScope scope, + SemanticSchema schema, boolean nonAgg) throws Exception { + List whereNode = getWhereMeasure(fields, queryMetrics, queryDimensions, + extendFields, datasource, scope, schema, nonAgg); dataSet.getMeasure().addAll(whereNode); // getWhere(outputSet,fields,queryMetrics,queryDimensions,datasource,scope,schema); } - public static void whereDimMetric( - List fields, - List queryMetrics, - List queryDimensions, - DataSource datasource, - SemanticSchema schema, - Set dimensions, - Set metrics) { + public static void whereDimMetric(List fields, List queryMetrics, + List queryDimensions, DataSource datasource, SemanticSchema schema, + Set dimensions, Set metrics) { for (String field : fields) { if (queryDimensions.contains(field) || queryMetrics.contains(field)) { continue; @@ -341,59 +269,40 @@ public class SourceRender extends Renderer { } } - private static void addField( - String field, - String oriField, - DataSource datasource, - SemanticSchema schema, - Set dimensions, - Set metrics) { - Optional dimension = - datasource.getDimensions().stream() - .filter(d -> d.getName().equalsIgnoreCase(field)) - .findFirst(); + private static void addField(String field, String oriField, DataSource datasource, + SemanticSchema schema, Set dimensions, Set metrics) { + Optional dimension = datasource.getDimensions().stream() + .filter(d -> d.getName().equalsIgnoreCase(field)).findFirst(); if (dimension.isPresent()) { dimensions.add(oriField); return; } - Optional identify = - datasource.getIdentifiers().stream() - .filter(i -> i.getName().equalsIgnoreCase(field)) - .findFirst(); + Optional identify = datasource.getIdentifiers().stream() + .filter(i -> i.getName().equalsIgnoreCase(field)).findFirst(); if (identify.isPresent()) { dimensions.add(oriField); return; } if (schema.getDimension().containsKey(datasource.getName())) { - Optional dataSourceDim = - schema.getDimension().get(datasource.getName()).stream() - .filter(d -> d.getName().equalsIgnoreCase(field)) - .findFirst(); + Optional dataSourceDim = schema.getDimension().get(datasource.getName()) + .stream().filter(d -> d.getName().equalsIgnoreCase(field)).findFirst(); if (dataSourceDim.isPresent()) { dimensions.add(oriField); return; } } - Optional metric = - datasource.getMeasures().stream() - .filter(m -> m.getName().equalsIgnoreCase(field)) - .findFirst(); + Optional metric = datasource.getMeasures().stream() + .filter(m -> m.getName().equalsIgnoreCase(field)).findFirst(); if (metric.isPresent()) { metrics.add(oriField); return; } - Optional datasourceMetric = - schema.getMetrics().stream() - .filter(m -> m.getName().equalsIgnoreCase(field)) - .findFirst(); + Optional datasourceMetric = schema.getMetrics().stream() + .filter(m -> m.getName().equalsIgnoreCase(field)).findFirst(); if (datasourceMetric.isPresent()) { - Set measures = - datasourceMetric.get().getMetricTypeParams().getMeasures().stream() - .map(m -> m.getName()) - .collect(Collectors.toSet()); - if (datasource.getMeasures().stream() - .map(m -> m.getName()) - .collect(Collectors.toSet()) + Set measures = datasourceMetric.get().getMetricTypeParams().getMeasures() + .stream().map(m -> m.getName()).collect(Collectors.toSet()); + if (datasource.getMeasures().stream().map(m -> m.getName()).collect(Collectors.toSet()) .containsAll(measures)) { metrics.add(oriField); return; @@ -402,25 +311,19 @@ public class SourceRender extends Renderer { } public static boolean isDimension(String name, DataSource datasource, SemanticSchema schema) { - Optional dimension = - datasource.getDimensions().stream() - .filter(d -> d.getName().equalsIgnoreCase(name)) - .findFirst(); + Optional dimension = datasource.getDimensions().stream() + .filter(d -> d.getName().equalsIgnoreCase(name)).findFirst(); if (dimension.isPresent()) { return true; } - Optional identify = - datasource.getIdentifiers().stream() - .filter(i -> i.getName().equalsIgnoreCase(name)) - .findFirst(); + Optional identify = datasource.getIdentifiers().stream() + .filter(i -> i.getName().equalsIgnoreCase(name)).findFirst(); if (identify.isPresent()) { return true; } if (schema.getDimension().containsKey(datasource.getName())) { - Optional dataSourceDim = - schema.getDimension().get(datasource.getName()).stream() - .filter(d -> d.getName().equalsIgnoreCase(name)) - .findFirst(); + Optional dataSourceDim = schema.getDimension().get(datasource.getName()) + .stream().filter(d -> d.getName().equalsIgnoreCase(name)).findFirst(); if (dataSourceDim.isPresent()) { return true; } @@ -430,30 +333,14 @@ public class SourceRender extends Renderer { private static void addTimeDimension(DataSource dataSource, List queryDimension) { if (Materialization.TimePartType.ZIPPER.equals(dataSource.getTimePartType())) { - Optional startTimeOp = - dataSource.getDimensions().stream() - .filter( - d -> - Constants.DIMENSION_TYPE_TIME.equalsIgnoreCase( - d.getType())) - .filter( - d -> - d.getName() - .startsWith( - Constants.MATERIALIZATION_ZIPPER_START)) - .findFirst(); - Optional endTimeOp = - dataSource.getDimensions().stream() - .filter( - d -> - Constants.DIMENSION_TYPE_TIME.equalsIgnoreCase( - d.getType())) - .filter( - d -> - d.getName() - .startsWith( - Constants.MATERIALIZATION_ZIPPER_END)) - .findFirst(); + Optional startTimeOp = dataSource.getDimensions().stream() + .filter(d -> Constants.DIMENSION_TYPE_TIME.equalsIgnoreCase(d.getType())) + .filter(d -> d.getName().startsWith(Constants.MATERIALIZATION_ZIPPER_START)) + .findFirst(); + Optional endTimeOp = dataSource.getDimensions().stream() + .filter(d -> Constants.DIMENSION_TYPE_TIME.equalsIgnoreCase(d.getType())) + .filter(d -> d.getName().startsWith(Constants.MATERIALIZATION_ZIPPER_END)) + .findFirst(); if (startTimeOp.isPresent() && !queryDimension.contains(startTimeOp.get().getName())) { queryDimension.add(startTimeOp.get().getName()); } @@ -461,26 +348,17 @@ public class SourceRender extends Renderer { queryDimension.add(endTimeOp.get().getName()); } } else { - Optional timeOp = - dataSource.getDimensions().stream() - .filter( - d -> - Constants.DIMENSION_TYPE_TIME.equalsIgnoreCase( - d.getType())) - .findFirst(); + Optional timeOp = dataSource.getDimensions().stream() + .filter(d -> Constants.DIMENSION_TYPE_TIME.equalsIgnoreCase(d.getType())) + .findFirst(); if (timeOp.isPresent() && !queryDimension.contains(timeOp.get().getName())) { queryDimension.add(timeOp.get().getName()); } } } - public void render( - MetricQueryParam metricQueryParam, - List dataSources, - SqlValidatorScope scope, - SemanticSchema schema, - boolean nonAgg) - throws Exception { + public void render(MetricQueryParam metricQueryParam, List dataSources, + SqlValidatorScope scope, SemanticSchema schema, boolean nonAgg) throws Exception { String queryWhere = metricQueryParam.getWhere(); Set whereFields = new HashSet<>(); List fieldWhere = new ArrayList<>(); @@ -493,17 +371,9 @@ public class SourceRender extends Renderer { } if (dataSources.size() == 1) { DataSource dataSource = dataSources.get(0); - super.tableView = - renderOne( - "", - fieldWhere, - metricQueryParam.getMetrics(), - metricQueryParam.getDimensions(), - metricQueryParam.getWhere(), - dataSource, - scope, - schema, - nonAgg); + super.tableView = renderOne("", fieldWhere, metricQueryParam.getMetrics(), + metricQueryParam.getDimensions(), metricQueryParam.getWhere(), dataSource, + scope, schema, nonAgg); return; } JoinRender joinRender = new JoinRender(); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/CalculateAggConverter.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/CalculateAggConverter.java index 3a56e9c46..252dd3184 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/CalculateAggConverter.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/CalculateAggConverter.java @@ -33,9 +33,8 @@ public class CalculateAggConverter implements QueryConverter { String sql(QueryParam queryParam, boolean isOver, boolean asWith, String metricSql); } - public DataSetQueryParam generateSqlCommend( - QueryStatement queryStatement, EngineType engineTypeEnum, String version) - throws Exception { + public DataSetQueryParam generateSqlCommend(QueryStatement queryStatement, + EngineType engineTypeEnum, String version) throws Exception { SqlGenerateUtils sqlGenerateUtils = ContextUtils.getBean(SqlGenerateUtils.class); QueryParam queryParam = queryStatement.getQueryParam(); // 同环比 @@ -53,24 +52,16 @@ public class CalculateAggConverter implements QueryConverter { metricTable.setWhere(where); metricTable.setAggOption(AggOption.AGGREGATION); sqlCommand.setTables(new ArrayList<>(Collections.singletonList(metricTable))); - String sql = - String.format( - "select %s from %s %s %s %s", - sqlGenerateUtils.getSelect(queryParam), - metricTableName, - sqlGenerateUtils.getGroupBy(queryParam), - sqlGenerateUtils.getOrderBy(queryParam), - sqlGenerateUtils.getLimit(queryParam)); + String sql = String.format("select %s from %s %s %s %s", + sqlGenerateUtils.getSelect(queryParam), metricTableName, + sqlGenerateUtils.getGroupBy(queryParam), sqlGenerateUtils.getOrderBy(queryParam), + sqlGenerateUtils.getLimit(queryParam)); if (!sqlGenerateUtils.isSupportWith(engineTypeEnum, version)) { sqlCommand.setSupportWith(false); - sql = - String.format( - "select %s from %s t0 %s %s %s", - sqlGenerateUtils.getSelect(queryParam), - metricTableName, - sqlGenerateUtils.getGroupBy(queryParam), - sqlGenerateUtils.getOrderBy(queryParam), - sqlGenerateUtils.getLimit(queryParam)); + sql = String.format("select %s from %s t0 %s %s %s", + sqlGenerateUtils.getSelect(queryParam), metricTableName, + sqlGenerateUtils.getGroupBy(queryParam), + sqlGenerateUtils.getOrderBy(queryParam), sqlGenerateUtils.getLimit(queryParam)); } sqlCommand.setSql(sql); return sqlCommand; @@ -107,32 +98,25 @@ public class CalculateAggConverter implements QueryConverter { @Override public void convert(QueryStatement queryStatement) throws Exception { Database database = queryStatement.getSemanticModel().getDatabase(); - DataSetQueryParam dataSetQueryParam = - generateSqlCommend( - queryStatement, - EngineType.fromString(database.getType().toUpperCase()), - database.getVersion()); + DataSetQueryParam dataSetQueryParam = generateSqlCommend(queryStatement, + EngineType.fromString(database.getType().toUpperCase()), database.getVersion()); queryStatement.setDataSetQueryParam(dataSetQueryParam); } /** Ratio */ public boolean isRatioAccept(QueryParam queryParam) { - Long ratioFuncNum = - queryParam.getAggregators().stream() - .filter( - f -> - (f.getFunc().equals(AggOperatorEnum.RATIO_ROLL) - || f.getFunc().equals(AggOperatorEnum.RATIO_OVER))) - .count(); + Long ratioFuncNum = queryParam.getAggregators().stream() + .filter(f -> (f.getFunc().equals(AggOperatorEnum.RATIO_ROLL) + || f.getFunc().equals(AggOperatorEnum.RATIO_OVER))) + .count(); if (ratioFuncNum > 0) { return true; } return false; } - public DataSetQueryParam generateRatioSqlCommand( - QueryStatement queryStatement, EngineType engineTypeEnum, String version) - throws Exception { + public DataSetQueryParam generateRatioSqlCommand(QueryStatement queryStatement, + EngineType engineTypeEnum, String version) throws Exception { SqlGenerateUtils sqlGenerateUtils = ContextUtils.getBean(SqlGenerateUtils.class); QueryParam queryParam = queryStatement.getQueryParam(); check(queryParam); @@ -161,21 +145,11 @@ public class CalculateAggConverter implements QueryConverter { sqlCommand.setSupportWith(false); } if (!engineTypeEnum.equals(engineTypeEnum.CLICKHOUSE)) { - sql = - new MysqlEngineSql() - .sql( - queryParam, - isOver, - sqlCommand.isSupportWith(), - metricTableName); + sql = new MysqlEngineSql().sql(queryParam, isOver, sqlCommand.isSupportWith(), + metricTableName); } else { - sql = - new CkEngineSql() - .sql( - queryParam, - isOver, - sqlCommand.isSupportWith(), - metricTableName); + sql = new CkEngineSql().sql(queryParam, isOver, sqlCommand.isSupportWith(), + metricTableName); } break; default: @@ -187,27 +161,17 @@ public class CalculateAggConverter implements QueryConverter { public class H2EngineSql implements EngineSql { public String getOverSelect(QueryParam queryParam, boolean isOver) { - String aggStr = - queryParam.getAggregators().stream() - .map( - f -> { - if (f.getFunc().equals(AggOperatorEnum.RATIO_OVER) - || f.getFunc().equals(AggOperatorEnum.RATIO_ROLL)) { - return String.format( - "( (%s-%s_roll)/cast(%s_roll as DOUBLE) ) as %s_%s,%s", - f.getColumn(), - f.getColumn(), - f.getColumn(), - f.getColumn(), - f.getFunc().getOperator(), - f.getColumn()); - } else { - return f.getColumn(); - } - }) - .collect(Collectors.joining(",")); - return CollectionUtils.isEmpty(queryParam.getGroups()) - ? aggStr + String aggStr = queryParam.getAggregators().stream().map(f -> { + if (f.getFunc().equals(AggOperatorEnum.RATIO_OVER) + || f.getFunc().equals(AggOperatorEnum.RATIO_ROLL)) { + return String.format("( (%s-%s_roll)/cast(%s_roll as DOUBLE) ) as %s_%s,%s", + f.getColumn(), f.getColumn(), f.getColumn(), f.getColumn(), + f.getFunc().getOperator(), f.getColumn()); + } else { + return f.getColumn(); + } + }).collect(Collectors.joining(",")); + return CollectionUtils.isEmpty(queryParam.getGroups()) ? aggStr : String.join(",", queryParam.getGroups()) + "," + aggStr; } @@ -227,48 +191,31 @@ public class CalculateAggConverter implements QueryConverter { return ""; } - public String getJoinOn( - QueryParam queryParam, boolean isOver, String aliasLeft, String aliasRight) { + public String getJoinOn(QueryParam queryParam, boolean isOver, String aliasLeft, + String aliasRight) { String timeDim = getTimeDim(queryParam); String timeSpan = getTimeSpan(queryParam, isOver, true); - String aggStr = - queryParam.getAggregators().stream() - .map( - f -> { - if (f.getFunc().equals(AggOperatorEnum.RATIO_OVER) - || f.getFunc().equals(AggOperatorEnum.RATIO_ROLL)) { - if (queryParam - .getDateInfo() - .getPeriod() - .equals(DatePeriodEnum.MONTH)) { - return String.format( - "%s is not null and %s = FORMATDATETIME(DATEADD(%s,CONCAT(%s,'-01')),'yyyy-MM') ", - aliasRight + timeDim, - aliasLeft + timeDim, - timeSpan, - aliasRight + timeDim); - } - if (queryParam - .getDateInfo() - .getPeriod() - .equals(DatePeriodEnum.WEEK) - && isOver) { - return String.format( - " DATE_TRUNC('week',DATEADD(%s,%s) ) = %s ", - getTimeSpan(queryParam, isOver, false), - aliasLeft + timeDim, - aliasRight + timeDim); - } - return String.format( - "%s = TIMESTAMPADD(%s,%s) ", - aliasLeft + timeDim, - timeSpan, - aliasRight + timeDim); - } else { - return f.getColumn(); - } - }) - .collect(Collectors.joining(" and ")); + String aggStr = queryParam.getAggregators().stream().map(f -> { + if (f.getFunc().equals(AggOperatorEnum.RATIO_OVER) + || f.getFunc().equals(AggOperatorEnum.RATIO_ROLL)) { + if (queryParam.getDateInfo().getPeriod().equals(DatePeriodEnum.MONTH)) { + return String.format( + "%s is not null and %s = FORMATDATETIME(DATEADD(%s,CONCAT(%s,'-01')),'yyyy-MM') ", + aliasRight + timeDim, aliasLeft + timeDim, timeSpan, + aliasRight + timeDim); + } + if (queryParam.getDateInfo().getPeriod().equals(DatePeriodEnum.WEEK) + && isOver) { + return String.format(" DATE_TRUNC('week',DATEADD(%s,%s) ) = %s ", + getTimeSpan(queryParam, isOver, false), aliasLeft + timeDim, + aliasRight + timeDim); + } + return String.format("%s = TIMESTAMPADD(%s,%s) ", aliasLeft + timeDim, timeSpan, + aliasRight + timeDim); + } else { + return f.getColumn(); + } + }).collect(Collectors.joining(" and ")); List groups = new ArrayList<>(); for (String group : queryParam.getGroups()) { if (group.equalsIgnoreCase(timeDim)) { @@ -276,71 +223,48 @@ public class CalculateAggConverter implements QueryConverter { } groups.add(aliasLeft + group + " = " + aliasRight + group); } - return CollectionUtils.isEmpty(groups) - ? aggStr + return CollectionUtils.isEmpty(groups) ? aggStr : String.join(" and ", groups) + " and " + aggStr + " "; } @Override public String sql(QueryParam queryParam, boolean isOver, boolean asWith, String metricSql) { - String sql = - String.format( - "select %s from ( select %s , %s from %s t0 left join %s t1 on %s ) metric_tb_src %s %s ", - getOverSelect(queryParam, isOver), - getAllSelect(queryParam, "t0."), - getAllJoinSelect(queryParam, "t1."), - metricSql, - metricSql, - getJoinOn(queryParam, isOver, "t0.", "t1."), - getOrderBy(queryParam), - getLimit(queryParam)); + String sql = String.format( + "select %s from ( select %s , %s from %s t0 left join %s t1 on %s ) metric_tb_src %s %s ", + getOverSelect(queryParam, isOver), getAllSelect(queryParam, "t0."), + getAllJoinSelect(queryParam, "t1."), metricSql, metricSql, + getJoinOn(queryParam, isOver, "t0.", "t1."), getOrderBy(queryParam), + getLimit(queryParam)); return sql; } } public class CkEngineSql extends MysqlEngineSql { - public String getJoinOn( - QueryParam queryParam, boolean isOver, String aliasLeft, String aliasRight) { + public String getJoinOn(QueryParam queryParam, boolean isOver, String aliasLeft, + String aliasRight) { String timeDim = getTimeDim(queryParam); String timeSpan = "INTERVAL " + getTimeSpan(queryParam, isOver, true); - String aggStr = - queryParam.getAggregators().stream() - .map( - f -> { - if (f.getFunc().equals(AggOperatorEnum.RATIO_OVER) - || f.getFunc().equals(AggOperatorEnum.RATIO_ROLL)) { - if (queryParam - .getDateInfo() - .getPeriod() - .equals(DatePeriodEnum.MONTH)) { - return String.format( - "toDate(CONCAT(%s,'-01')) = date_add(toDate(CONCAT(%s,'-01')),%s) ", - aliasLeft + timeDim, - aliasRight + timeDim, - timeSpan); - } - if (queryParam - .getDateInfo() - .getPeriod() - .equals(DatePeriodEnum.WEEK) - && isOver) { - return String.format( - "toMonday(date_add(%s ,INTERVAL %s) ) = %s", - aliasLeft + timeDim, - getTimeSpan(queryParam, isOver, false), - aliasRight + timeDim); - } - return String.format( - "%s = date_add(%s,%s) ", - aliasLeft + timeDim, - aliasRight + timeDim, - timeSpan); - } else { - return f.getColumn(); - } - }) - .collect(Collectors.joining(" and ")); + String aggStr = queryParam.getAggregators().stream().map(f -> { + if (f.getFunc().equals(AggOperatorEnum.RATIO_OVER) + || f.getFunc().equals(AggOperatorEnum.RATIO_ROLL)) { + if (queryParam.getDateInfo().getPeriod().equals(DatePeriodEnum.MONTH)) { + return String.format( + "toDate(CONCAT(%s,'-01')) = date_add(toDate(CONCAT(%s,'-01')),%s) ", + aliasLeft + timeDim, aliasRight + timeDim, timeSpan); + } + if (queryParam.getDateInfo().getPeriod().equals(DatePeriodEnum.WEEK) + && isOver) { + return String.format("toMonday(date_add(%s ,INTERVAL %s) ) = %s", + aliasLeft + timeDim, getTimeSpan(queryParam, isOver, false), + aliasRight + timeDim); + } + return String.format("%s = date_add(%s,%s) ", aliasLeft + timeDim, + aliasRight + timeDim, timeSpan); + } else { + return f.getColumn(); + } + }).collect(Collectors.joining(" and ")); List groups = new ArrayList<>(); for (String group : queryParam.getGroups()) { if (group.equalsIgnoreCase(timeDim)) { @@ -348,8 +272,7 @@ public class CalculateAggConverter implements QueryConverter { } groups.add(aliasLeft + group + " = " + aliasRight + group); } - return CollectionUtils.isEmpty(groups) - ? aggStr + return CollectionUtils.isEmpty(groups) ? aggStr : String.join(" and ", groups) + " and " + aggStr + " "; } @@ -358,25 +281,17 @@ public class CalculateAggConverter implements QueryConverter { if (!asWith) { return String.format( "select %s from ( select %s , %s from %s t0 left join %s t1 on %s ) metric_tb_src %s %s ", - getOverSelect(queryParam, isOver), - getAllSelect(queryParam, "t0."), - getAllJoinSelect(queryParam, "t1."), - metricSql, - metricSql, - getJoinOn(queryParam, isOver, "t0.", "t1."), - getOrderBy(queryParam), + getOverSelect(queryParam, isOver), getAllSelect(queryParam, "t0."), + getAllJoinSelect(queryParam, "t1."), metricSql, metricSql, + getJoinOn(queryParam, isOver, "t0.", "t1."), getOrderBy(queryParam), getLimit(queryParam)); } return String.format( ",t0 as (select * from %s),t1 as (select * from %s) select %s from ( select %s , %s " + "from t0 left join t1 on %s ) metric_tb_src %s %s ", - metricSql, - metricSql, - getOverSelect(queryParam, isOver), - getAllSelect(queryParam, "t0."), - getAllJoinSelect(queryParam, "t1."), - getJoinOn(queryParam, isOver, "t0.", "t1."), - getOrderBy(queryParam), + metricSql, metricSql, getOverSelect(queryParam, isOver), + getAllSelect(queryParam, "t0."), getAllJoinSelect(queryParam, "t1."), + getJoinOn(queryParam, isOver, "t0.", "t1."), getOrderBy(queryParam), getLimit(queryParam)); } } @@ -400,72 +315,44 @@ public class CalculateAggConverter implements QueryConverter { } public String getOverSelect(QueryParam queryParam, boolean isOver) { - String aggStr = - queryParam.getAggregators().stream() - .map( - f -> { - if (f.getFunc().equals(AggOperatorEnum.RATIO_OVER) - || f.getFunc().equals(AggOperatorEnum.RATIO_ROLL)) { - return String.format( - "if(%s_roll!=0, (%s-%s_roll)/%s_roll , 0) as %s_%s,%s", - f.getColumn(), - f.getColumn(), - f.getColumn(), - f.getColumn(), - f.getColumn(), - f.getFunc().getOperator(), - f.getColumn()); - } else { - return f.getColumn(); - } - }) - .collect(Collectors.joining(",")); - return CollectionUtils.isEmpty(queryParam.getGroups()) - ? aggStr + String aggStr = queryParam.getAggregators().stream().map(f -> { + if (f.getFunc().equals(AggOperatorEnum.RATIO_OVER) + || f.getFunc().equals(AggOperatorEnum.RATIO_ROLL)) { + return String.format("if(%s_roll!=0, (%s-%s_roll)/%s_roll , 0) as %s_%s,%s", + f.getColumn(), f.getColumn(), f.getColumn(), f.getColumn(), + f.getColumn(), f.getFunc().getOperator(), f.getColumn()); + } else { + return f.getColumn(); + } + }).collect(Collectors.joining(",")); + return CollectionUtils.isEmpty(queryParam.getGroups()) ? aggStr : String.join(",", queryParam.getGroups()) + "," + aggStr; } - public String getJoinOn( - QueryParam queryParam, boolean isOver, String aliasLeft, String aliasRight) { + public String getJoinOn(QueryParam queryParam, boolean isOver, String aliasLeft, + String aliasRight) { String timeDim = getTimeDim(queryParam); String timeSpan = "INTERVAL " + getTimeSpan(queryParam, isOver, true); - String aggStr = - queryParam.getAggregators().stream() - .map( - f -> { - if (f.getFunc().equals(AggOperatorEnum.RATIO_OVER) - || f.getFunc().equals(AggOperatorEnum.RATIO_ROLL)) { - if (queryParam - .getDateInfo() - .getPeriod() - .equals(DatePeriodEnum.MONTH)) { - return String.format( - "%s = DATE_FORMAT(date_add(CONCAT(%s,'-01'), %s),'%%Y-%%m') ", - aliasLeft + timeDim, - aliasRight + timeDim, - timeSpan); - } - if (queryParam - .getDateInfo() - .getPeriod() - .equals(DatePeriodEnum.WEEK) - && isOver) { - return String.format( - "to_monday(date_add(%s ,INTERVAL %s) ) = %s", - aliasLeft + timeDim, - getTimeSpan(queryParam, isOver, false), - aliasRight + timeDim); - } - return String.format( - "%s = date_add(%s,%s) ", - aliasLeft + timeDim, - aliasRight + timeDim, - timeSpan); - } else { - return f.getColumn(); - } - }) - .collect(Collectors.joining(" and ")); + String aggStr = queryParam.getAggregators().stream().map(f -> { + if (f.getFunc().equals(AggOperatorEnum.RATIO_OVER) + || f.getFunc().equals(AggOperatorEnum.RATIO_ROLL)) { + if (queryParam.getDateInfo().getPeriod().equals(DatePeriodEnum.MONTH)) { + return String.format( + "%s = DATE_FORMAT(date_add(CONCAT(%s,'-01'), %s),'%%Y-%%m') ", + aliasLeft + timeDim, aliasRight + timeDim, timeSpan); + } + if (queryParam.getDateInfo().getPeriod().equals(DatePeriodEnum.WEEK) + && isOver) { + return String.format("to_monday(date_add(%s ,INTERVAL %s) ) = %s", + aliasLeft + timeDim, getTimeSpan(queryParam, isOver, false), + aliasRight + timeDim); + } + return String.format("%s = date_add(%s,%s) ", aliasLeft + timeDim, + aliasRight + timeDim, timeSpan); + } else { + return f.getColumn(); + } + }).collect(Collectors.joining(" and ")); List groups = new ArrayList<>(); for (String group : queryParam.getGroups()) { if (group.equalsIgnoreCase(timeDim)) { @@ -473,38 +360,26 @@ public class CalculateAggConverter implements QueryConverter { } groups.add(aliasLeft + group + " = " + aliasRight + group); } - return CollectionUtils.isEmpty(groups) - ? aggStr + return CollectionUtils.isEmpty(groups) ? aggStr : String.join(" and ", groups) + " and " + aggStr + " "; } @Override public String sql(QueryParam queryParam, boolean isOver, boolean asWith, String metricSql) { - String sql = - String.format( - "select %s from ( select %s , %s from %s t0 left join %s t1 on %s ) metric_tb_src %s %s ", - getOverSelect(queryParam, isOver), - getAllSelect(queryParam, "t0."), - getAllJoinSelect(queryParam, "t1."), - metricSql, - metricSql, - getJoinOn(queryParam, isOver, "t0.", "t1."), - getOrderBy(queryParam), - getLimit(queryParam)); + String sql = String.format( + "select %s from ( select %s , %s from %s t0 left join %s t1 on %s ) metric_tb_src %s %s ", + getOverSelect(queryParam, isOver), getAllSelect(queryParam, "t0."), + getAllJoinSelect(queryParam, "t1."), metricSql, metricSql, + getJoinOn(queryParam, isOver, "t0.", "t1."), getOrderBy(queryParam), + getLimit(queryParam)); return sql; } } private String getAllJoinSelect(QueryParam queryParam, String alias) { - String aggStr = - queryParam.getAggregators().stream() - .map( - f -> - getSelectField(f, alias) - + " as " - + getSelectField(f, "") - + "_roll") - .collect(Collectors.joining(",")); + String aggStr = queryParam.getAggregators().stream() + .map(f -> getSelectField(f, alias) + " as " + getSelectField(f, "") + "_roll") + .collect(Collectors.joining(",")); List groups = new ArrayList<>(); for (String group : queryParam.getGroups()) { groups.add(alias + group + " as " + group + "_roll"); @@ -514,8 +389,7 @@ public class CalculateAggConverter implements QueryConverter { private String getGroupDimWithOutTime(QueryParam queryParam) { String timeDim = getTimeDim(queryParam); - return queryParam.getGroups().stream() - .filter(f -> !f.equalsIgnoreCase(timeDim)) + return queryParam.getGroups().stream().filter(f -> !f.equalsIgnoreCase(timeDim)) .collect(Collectors.joining(",")); } @@ -532,12 +406,9 @@ public class CalculateAggConverter implements QueryConverter { } private String getAllSelect(QueryParam queryParam, String alias) { - String aggStr = - queryParam.getAggregators().stream() - .map(f -> getSelectField(f, alias)) - .collect(Collectors.joining(",")); - return CollectionUtils.isEmpty(queryParam.getGroups()) - ? aggStr + String aggStr = queryParam.getAggregators().stream().map(f -> getSelectField(f, alias)) + .collect(Collectors.joining(",")); + return CollectionUtils.isEmpty(queryParam.getGroups()) ? aggStr : alias + String.join("," + alias, queryParam.getGroups()) + "," + aggStr; } @@ -562,22 +433,16 @@ public class CalculateAggConverter implements QueryConverter { } private boolean isOverRatio(QueryParam queryParam) { - Long overCt = - queryParam.getAggregators().stream() - .filter(f -> f.getFunc().equals(AggOperatorEnum.RATIO_OVER)) - .count(); + Long overCt = queryParam.getAggregators().stream() + .filter(f -> f.getFunc().equals(AggOperatorEnum.RATIO_OVER)).count(); return overCt > 0; } private void check(QueryParam queryParam) throws Exception { - Long ratioOverNum = - queryParam.getAggregators().stream() - .filter(f -> f.getFunc().equals(AggOperatorEnum.RATIO_OVER)) - .count(); - Long ratioRollNum = - queryParam.getAggregators().stream() - .filter(f -> f.getFunc().equals(AggOperatorEnum.RATIO_ROLL)) - .count(); + Long ratioOverNum = queryParam.getAggregators().stream() + .filter(f -> f.getFunc().equals(AggOperatorEnum.RATIO_OVER)).count(); + Long ratioRollNum = queryParam.getAggregators().stream() + .filter(f -> f.getFunc().equals(AggOperatorEnum.RATIO_ROLL)).count(); if (ratioOverNum > 0 && ratioRollNum > 0) { throw new Exception("not support over ratio and roll ratio together "); } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/DefaultDimValueConverter.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/DefaultDimValueConverter.java index 28f0f8572..b06e363f4 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/DefaultDimValueConverter.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/DefaultDimValueConverter.java @@ -34,18 +34,16 @@ public class DefaultDimValueConverter implements QueryConverter { @Override public void convert(QueryStatement queryStatement) { - List dimensions = - queryStatement.getSemanticModel().getDimensions().stream() - .filter(dimension -> !CollectionUtils.isEmpty(dimension.getDefaultValues())) - .collect(Collectors.toList()); + List dimensions = queryStatement.getSemanticModel().getDimensions().stream() + .filter(dimension -> !CollectionUtils.isEmpty(dimension.getDefaultValues())) + .collect(Collectors.toList()); if (CollectionUtils.isEmpty(dimensions)) { return; } String sql = queryStatement.getDataSetQueryParam().getSql(); - List whereFields = - SqlSelectHelper.getWhereFields(sql).stream() - .filter(field -> !TimeDimensionEnum.containsTimeDimension(field)) - .collect(Collectors.toList()); + List whereFields = SqlSelectHelper.getWhereFields(sql).stream() + .filter(field -> !TimeDimensionEnum.containsTimeDimension(field)) + .collect(Collectors.toList()); if (!CollectionUtils.isEmpty(whereFields)) { return; } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/ParserDefaultConverter.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/ParserDefaultConverter.java index 731df5cbd..33102db14 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/ParserDefaultConverter.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/ParserDefaultConverter.java @@ -42,8 +42,8 @@ public class ParserDefaultConverter implements QueryConverter { BeanUtils.copyProperties(metricReq, metricQueryParam); } - public MetricQueryParam generateSqlCommand( - QueryParam queryParam, QueryStatement queryStatement) { + public MetricQueryParam generateSqlCommand(QueryParam queryParam, + QueryStatement queryStatement) { SqlGenerateUtils sqlGenerateUtils = ContextUtils.getBean(SqlGenerateUtils.class); MetricQueryParam metricQueryParam = new MetricQueryParam(); metricQueryParam.setMetrics(queryParam.getMetrics()); @@ -52,10 +52,9 @@ public class ParserDefaultConverter implements QueryConverter { log.info("in generateSqlCommend, complete where:{}", where); metricQueryParam.setWhere(where); - metricQueryParam.setOrder( - queryParam.getOrders().stream() - .map(order -> new ColumnOrder(order.getColumn(), order.getDirection())) - .collect(Collectors.toList())); + metricQueryParam.setOrder(queryParam.getOrders().stream() + .map(order -> new ColumnOrder(order.getColumn(), order.getDirection())) + .collect(Collectors.toList())); metricQueryParam.setLimit(queryParam.getLimit()); // support detail query diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/SqlVariableParseConverter.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/SqlVariableParseConverter.java index e6a05d302..179ead737 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/SqlVariableParseConverter.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/translator/converter/SqlVariableParseConverter.java @@ -33,19 +33,14 @@ public class SqlVariableParseConverter implements QueryConverter { return; } for (ModelResp modelResp : modelResps) { - if (ModelDefineType.SQL_QUERY - .getName() + if (ModelDefineType.SQL_QUERY.getName() .equalsIgnoreCase(modelResp.getModelDetail().getQueryType())) { String sqlParsed = - SqlVariableParseUtils.parse( - modelResp.getModelDetail().getSqlQuery(), + SqlVariableParseUtils.parse(modelResp.getModelDetail().getSqlQuery(), modelResp.getModelDetail().getSqlVariables(), queryStatement.getQueryParam().getParams()); - DataSource dataSource = - queryStatement - .getSemanticModel() - .getDatasourceMap() - .get(modelResp.getBizName()); + DataSource dataSource = queryStatement.getSemanticModel().getDatasourceMap() + .get(modelResp.getBizName()); dataSource.setSqlQuery(sqlParsed); } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/ComponentFactory.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/ComponentFactory.java index 2e37a3c6c..37305db94 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/ComponentFactory.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/ComponentFactory.java @@ -120,15 +120,13 @@ public class ComponentFactory { } private static List init(Class factoryType, List list) { - list.addAll( - SpringFactoriesLoader.loadFactories( - factoryType, Thread.currentThread().getContextClassLoader())); + list.addAll(SpringFactoriesLoader.loadFactories(factoryType, + Thread.currentThread().getContextClassLoader())); return list; } private static T init(Class factoryType) { - return SpringFactoriesLoader.loadFactories( - factoryType, Thread.currentThread().getContextClassLoader()) - .get(0); + return SpringFactoriesLoader + .loadFactories(factoryType, Thread.currentThread().getContextClassLoader()).get(0); } } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/DataTransformUtils.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/DataTransformUtils.java index 691e097b7..3fcf769cf 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/DataTransformUtils.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/DataTransformUtils.java @@ -15,11 +15,8 @@ import java.util.stream.Collectors; /** transform query results to return the users */ public class DataTransformUtils { - public static List> transform( - List> originalData, - String metric, - List groups, - DateConf dateConf) { + public static List> transform(List> originalData, + String metric, List groups, DateConf dateConf) { List dateList = dateConf.getDateList(); List> transposedData = new ArrayList<>(); for (Map originalRow : originalData) { @@ -29,14 +26,12 @@ public class DataTransformUtils { transposedRow.put(key, originalRow.get(key)); } } - transposedRow.put( - String.valueOf(originalRow.get(getTimeDimension(dateConf))), + transposedRow.put(String.valueOf(originalRow.get(getTimeDimension(dateConf))), originalRow.get(metric)); transposedData.add(transposedRow); } - Map>> dataMerge = - transposedData.stream() - .collect(Collectors.groupingBy(row -> getRowKey(row, groups))); + Map>> dataMerge = transposedData.stream() + .collect(Collectors.groupingBy(row -> getRowKey(row, groups))); List> resultData = Lists.newArrayList(); for (List> data : dataMerge.values()) { Map rowData = new HashMap<>(); diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/JdbcDataSourceUtils.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/JdbcDataSourceUtils.java index 11ef81875..1d4729c4a 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/JdbcDataSourceUtils.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/JdbcDataSourceUtils.java @@ -31,7 +31,8 @@ import static com.tencent.supersonic.common.pojo.Constants.SPACE; @Slf4j public class JdbcDataSourceUtils { - @Getter private static Set releaseSourceSet = new HashSet(); + @Getter + private static Set releaseSourceSet = new HashSet(); private JdbcDataSource jdbcDataSource; public JdbcDataSourceUtils(JdbcDataSource jdbcDataSource) { @@ -46,9 +47,8 @@ public class JdbcDataSourceUtils { log.error(e.toString(), e); return false; } - try (Connection con = - DriverManager.getConnection( - database.getUrl(), database.getUsername(), database.passwordDecrypt())) { + try (Connection con = DriverManager.getConnection(database.getUrl(), database.getUsername(), + database.passwordDecrypt())) { return con != null; } catch (SQLException e) { log.error(e.toString(), e); @@ -116,8 +116,7 @@ public class JdbcDataSourceUtils { log.error("e", e); } - if (!StringUtils.isEmpty(className) - && !className.contains("com.sun.proxy") + if (!StringUtils.isEmpty(className) && !className.contains("com.sun.proxy") && !className.contains("net.sf.cglib.proxy")) { return className; } @@ -129,13 +128,8 @@ public class JdbcDataSourceUtils { throw new RuntimeException("Not supported data type: jdbcUrl=" + jdbcUrl); } - public static String getKey( - String name, - String jdbcUrl, - String username, - String password, - String version, - boolean isExt) { + public static String getKey(String name, String jdbcUrl, String username, String password, + String version, boolean isExt) { StringBuilder sb = new StringBuilder(); @@ -165,10 +159,8 @@ public class JdbcDataSourceUtils { return dataSource.getConnection(); } catch (Exception e) { log.error("Get connection error, jdbcUrl:{}, e:{}", database.getUrl(), e); - throw new RuntimeException( - "Get connection error, jdbcUrl:" - + database.getUrl() - + " you can try again later or reset datasource"); + throw new RuntimeException("Get connection error, jdbcUrl:" + database.getUrl() + + " you can try again later or reset datasource"); } } return conn; @@ -176,7 +168,7 @@ public class JdbcDataSourceUtils { private Connection getConnectionWithRetry(Database database) { int rc = 1; - for (; ; ) { + for (;;) { if (rc > 3) { return null; diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/JdbcDuckDbUtils.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/JdbcDuckDbUtils.java index 89eb30fd9..9e2bcacf9 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/JdbcDuckDbUtils.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/JdbcDuckDbUtils.java @@ -11,14 +11,8 @@ import java.util.stream.Collectors; /** tools functions to duckDb query */ public class JdbcDuckDbUtils { - public static void attachMysql( - DuckDbSource duckDbSource, - String host, - Integer port, - String user, - String password, - String database) - throws Exception { + public static void attachMysql(DuckDbSource duckDbSource, String host, Integer port, + String user, String password, String database) throws Exception { try { duckDbSource.execute("INSTALL mysql"); duckDbSource.execute("load mysql"); @@ -40,25 +34,20 @@ public class JdbcDuckDbUtils { if (!queryResultWithColumns.getResultList().isEmpty()) { return queryResultWithColumns.getResultList().stream() .filter(l -> l.containsKey("name") && Objects.nonNull(l.get("name"))) - .map(l -> (String) l.get("name")) - .collect(Collectors.toList()); + .map(l -> (String) l.get("name")).collect(Collectors.toList()); } return new ArrayList<>(); } - public static List getParquetPartition( - DuckDbSource duckDbSource, String parquetPath, String partitionName) throws Exception { + public static List getParquetPartition(DuckDbSource duckDbSource, String parquetPath, + String partitionName) throws Exception { SemanticQueryResp queryResultWithColumns = new SemanticQueryResp(); - duckDbSource.query( - String.format( - "SELECT distinct %s as partition FROM read_parquet('%s')", - partitionName, parquetPath), - queryResultWithColumns); + duckDbSource.query(String.format("SELECT distinct %s as partition FROM read_parquet('%s')", + partitionName, parquetPath), queryResultWithColumns); if (!queryResultWithColumns.getResultList().isEmpty()) { return queryResultWithColumns.getResultList().stream() .filter(l -> l.containsKey("partition") && Objects.nonNull(l.get("partition"))) - .map(l -> (String) l.get("partition")) - .collect(Collectors.toList()); + .map(l -> (String) l.get("partition")).collect(Collectors.toList()); } return new ArrayList<>(); } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SchemaMatchHelper.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SchemaMatchHelper.java index 7ca9930f3..97e9b0519 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SchemaMatchHelper.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SchemaMatchHelper.java @@ -19,24 +19,21 @@ public class SchemaMatchHelper { } Set metricDimensionDetectWordSet = - matches.stream() - .filter(SchemaMatchHelper::isMetricOrDimension) - .map(SchemaElementMatch::getDetectWord) - .collect(Collectors.toSet()); + matches.stream().filter(SchemaMatchHelper::isMetricOrDimension) + .map(SchemaElementMatch::getDetectWord).collect(Collectors.toSet()); - matches.removeIf( - elementMatch -> { - if (!isMetricOrDimension(elementMatch)) { - return false; - } - for (String detectWord : metricDimensionDetectWordSet) { - if (detectWord.startsWith(elementMatch.getDetectWord()) - && detectWord.length() > elementMatch.getDetectWord().length()) { - return true; - } - } - return false; - }); + matches.removeIf(elementMatch -> { + if (!isMetricOrDimension(elementMatch)) { + return false; + } + for (String detectWord : metricDimensionDetectWordSet) { + if (detectWord.startsWith(elementMatch.getDetectWord()) + && detectWord.length() > elementMatch.getDetectWord().length()) { + return true; + } + } + return false; + }); } private static boolean isMetricOrDimension(SchemaElementMatch elementMatch) { diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlGenerateUtils.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlGenerateUtils.java index 9598f3034..eb734f1c4 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlGenerateUtils.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlGenerateUtils.java @@ -54,9 +54,7 @@ public class SqlGenerateUtils { private final ExecutorConfig executorConfig; - public SqlGenerateUtils( - SqlFilterUtils sqlFilterUtils, - DateModeUtils dateModeUtils, + public SqlGenerateUtils(SqlFilterUtils sqlFilterUtils, DateModeUtils dateModeUtils, ExecutorConfig executorConfig) { this.sqlFilterUtils = sqlFilterUtils; this.dateModeUtils = dateModeUtils; @@ -95,22 +93,16 @@ public class SqlGenerateUtils { } public String getSelect(QueryParam queryParam) { - String aggStr = - queryParam.getAggregators().stream() - .map(this::getSelectField) - .collect(Collectors.joining(",")); - return CollectionUtils.isEmpty(queryParam.getGroups()) - ? aggStr + String aggStr = queryParam.getAggregators().stream().map(this::getSelectField) + .collect(Collectors.joining(",")); + return CollectionUtils.isEmpty(queryParam.getGroups()) ? aggStr : String.join(",", queryParam.getGroups()) + "," + aggStr; } public String getSelect(QueryParam queryParam, Map deriveMetrics) { - String aggStr = - queryParam.getAggregators().stream() - .map(a -> getSelectField(a, deriveMetrics)) - .collect(Collectors.joining(",")); - return CollectionUtils.isEmpty(queryParam.getGroups()) - ? aggStr + String aggStr = queryParam.getAggregators().stream() + .map(a -> getSelectField(a, deriveMetrics)).collect(Collectors.joining(",")); + return CollectionUtils.isEmpty(queryParam.getGroups()) ? aggStr : String.join(",", queryParam.getGroups()) + "," + aggStr; } @@ -121,20 +113,12 @@ public class SqlGenerateUtils { if (CollectionUtils.isEmpty(agg.getArgs())) { return agg.getFunc() + "( " + agg.getColumn() + " ) AS " + agg.getColumn() + " "; } - return agg.getFunc() - + "( " + return agg.getFunc() + "( " + agg.getArgs().stream() - .map( - arg -> - arg.equals(agg.getColumn()) - ? arg - : (StringUtils.isNumeric(arg) - ? arg - : ("'" + arg + "'"))) + .map(arg -> arg.equals(agg.getColumn()) ? arg + : (StringUtils.isNumeric(arg) ? arg : ("'" + arg + "'"))) .collect(Collectors.joining(",")) - + " ) AS " - + agg.getColumn() - + " "; + + " ) AS " + agg.getColumn() + " "; } public String getSelectField(final Aggregator agg, Map deriveMetrics) { @@ -155,10 +139,9 @@ public class SqlGenerateUtils { if (CollectionUtils.isEmpty(queryParam.getOrders())) { return ""; } - return "order by " - + queryParam.getOrders().stream() - .map(order -> " " + order.getColumn() + " " + order.getDirection() + " ") - .collect(Collectors.joining(",")); + return "order by " + queryParam.getOrders().stream() + .map(order -> " " + order.getColumn() + " " + order.getDirection() + " ") + .collect(Collectors.joining(",")); } public String getOrderBy(QueryParam queryParam, Map deriveMetrics) { @@ -169,18 +152,11 @@ public class SqlGenerateUtils { .anyMatch(o -> deriveMetrics.containsKey(o.getColumn()))) { return getOrderBy(queryParam); } - return "order by " - + queryParam.getOrders().stream() - .map( - order -> - " " - + (deriveMetrics.containsKey(order.getColumn()) - ? deriveMetrics.get(order.getColumn()) - : order.getColumn()) - + " " - + order.getDirection() - + " ") - .collect(Collectors.joining(",")); + return "order by " + queryParam.getOrders().stream() + .map(order -> " " + (deriveMetrics.containsKey(order.getColumn()) + ? deriveMetrics.get(order.getColumn()) + : order.getColumn()) + " " + order.getDirection() + " ") + .collect(Collectors.joining(",")); } public String generateWhere(QueryParam queryParam, ItemDateResp itemDateResp) { @@ -190,8 +166,8 @@ public class SqlGenerateUtils { return mergeDateWhereClause(queryParam, whereClauseFromFilter, whereFromDate); } - private String mergeDateWhereClause( - QueryParam queryParam, String whereClauseFromFilter, String whereFromDate) { + private String mergeDateWhereClause(QueryParam queryParam, String whereClauseFromFilter, + String whereFromDate) { if (StringUtils.isNotEmpty(whereFromDate) && StringUtils.isNotEmpty(whereClauseFromFilter)) { return String.format("%s AND (%s)", whereFromDate, whereClauseFromFilter); @@ -209,9 +185,8 @@ public class SqlGenerateUtils { } public String getDateWhereClause(DateConf dateInfo, ItemDateResp dateDate) { - if (Objects.isNull(dateDate) - || StringUtils.isEmpty(dateDate.getStartDate()) - && StringUtils.isEmpty(dateDate.getEndDate())) { + if (Objects.isNull(dateDate) || StringUtils.isEmpty(dateDate.getStartDate()) + && StringUtils.isEmpty(dateDate.getEndDate())) { if (dateInfo.getDateMode().equals(DateConf.DateMode.LIST)) { return dateModeUtils.listDateStr(dateInfo); } @@ -228,8 +203,8 @@ public class SqlGenerateUtils { return dateModeUtils.getDateWhereStr(dateInfo, dateDate); } - public Triple getBeginEndTime( - QueryParam queryParam, ItemDateResp dataDate) { + public Triple getBeginEndTime(QueryParam queryParam, + ItemDateResp dataDate) { if (Objects.isNull(queryParam.getDateInfo())) { return Triple.of("", "", ""); } @@ -243,16 +218,13 @@ public class SqlGenerateUtils { case BETWEEN: return Triple.of(dateInfo, dateConf.getStartDate(), dateConf.getEndDate()); case LIST: - return Triple.of( - dateInfo, - Collections.min(dateConf.getDateList()), + return Triple.of(dateInfo, Collections.min(dateConf.getDateList()), Collections.max(dateConf.getDateList())); case RECENT: LocalDate dateMax = LocalDate.now().minusDays(1); LocalDate dateMin = dateMax.minusDays(dateConf.getUnit() - 1); if (Objects.isNull(dataDate)) { - return Triple.of( - dateInfo, + return Triple.of(dateInfo, dateMin.format(DateTimeFormatter.ofPattern(DAY_FORMAT)), dateMax.format(DateTimeFormatter.ofPattern(DAY_FORMAT))); } @@ -270,11 +242,8 @@ public class SqlGenerateUtils { dateModeUtils.recentMonth(dataDate, dateConf); Optional minBegins = rets.stream().map(i -> i.left).sorted().findFirst(); - Optional maxBegins = - rets.stream() - .map(i -> i.right) - .sorted(Comparator.reverseOrder()) - .findFirst(); + Optional maxBegins = rets.stream().map(i -> i.right) + .sorted(Comparator.reverseOrder()).findFirst(); if (minBegins.isPresent() && maxBegins.isPresent()) { return Triple.of(dateInfo, minBegins.get(), maxBegins.get()); } @@ -290,13 +259,11 @@ public class SqlGenerateUtils { } public boolean isSupportWith(EngineType engineTypeEnum, String version) { - if (engineTypeEnum.equals(EngineType.MYSQL) - && Objects.nonNull(version) + if (engineTypeEnum.equals(EngineType.MYSQL) && Objects.nonNull(version) && version.startsWith(executorConfig.getMysqlLowVersion())) { return false; } - if (engineTypeEnum.equals(EngineType.CLICKHOUSE) - && Objects.nonNull(version) + if (engineTypeEnum.equals(EngineType.CLICKHOUSE) && Objects.nonNull(version) && StringUtil.compareVersion(version, executorConfig.getCkLowVersion()) < 0) { return false; } @@ -307,44 +274,28 @@ public class SqlGenerateUtils { return modelBizName + UNDERLINE + executorConfig.getInternalMetricNameSuffix(); } - public String generateDerivedMetric( - final List metricResps, - final Set allFields, - final Map allMeasures, - final List dimensionResps, - final String expression, - final MetricDefineType metricDefineType, - AggOption aggOption, - Set visitedMetric, - Set measures, - Set dimensions) { + public String generateDerivedMetric(final List metricResps, + final Set allFields, final Map allMeasures, + final List dimensionResps, final String expression, + final MetricDefineType metricDefineType, AggOption aggOption, Set visitedMetric, + Set measures, Set dimensions) { Set fields = SqlSelectHelper.getColumnFromExpr(expression); if (!CollectionUtils.isEmpty(fields)) { Map replace = new HashMap<>(); for (String field : fields) { switch (metricDefineType) { case METRIC: - Optional metricItem = - metricResps.stream() - .filter(m -> m.getBizName().equalsIgnoreCase(field)) - .findFirst(); + Optional metricItem = metricResps.stream() + .filter(m -> m.getBizName().equalsIgnoreCase(field)).findFirst(); if (metricItem.isPresent()) { if (visitedMetric.contains(field)) { break; } - replace.put( - field, - generateDerivedMetric( - metricResps, - allFields, - allMeasures, - dimensionResps, - getExpr(metricItem.get()), - metricItem.get().getMetricDefineType(), - aggOption, - visitedMetric, - measures, - dimensions)); + replace.put(field, + generateDerivedMetric(metricResps, allFields, allMeasures, + dimensionResps, getExpr(metricItem.get()), + metricItem.get().getMetricDefineType(), aggOption, + visitedMetric, measures, dimensions)); visitedMetric.add(field); } break; @@ -356,10 +307,8 @@ public class SqlGenerateUtils { break; case FIELD: if (allFields.contains(field)) { - Optional dimensionItem = - dimensionResps.stream() - .filter(d -> d.getBizName().equals(field)) - .findFirst(); + Optional dimensionItem = dimensionResps.stream() + .filter(d -> d.getBizName().equals(field)).findFirst(); if (dimensionItem.isPresent()) { dimensions.add(field); } else { @@ -382,17 +331,11 @@ public class SqlGenerateUtils { public String getExpr(Measure measure, AggOption aggOption) { if (AggOperatorEnum.COUNT_DISTINCT.getOperator().equalsIgnoreCase(measure.getAgg())) { - return AggOption.NATIVE.equals(aggOption) - ? measure.getBizName() - : AggOperatorEnum.COUNT.getOperator() - + " ( " - + AggOperatorEnum.DISTINCT - + " " - + measure.getBizName() - + " ) "; + return AggOption.NATIVE.equals(aggOption) ? measure.getBizName() + : AggOperatorEnum.COUNT.getOperator() + " ( " + AggOperatorEnum.DISTINCT + " " + + measure.getBizName() + " ) "; } - return AggOption.NATIVE.equals(aggOption) - ? measure.getBizName() + return AggOption.NATIVE.equals(aggOption) ? measure.getBizName() : measure.getAgg() + " ( " + measure.getBizName() + " ) "; } diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlUtils.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlUtils.java index 306ea1498..04cf92616 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlUtils.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlUtils.java @@ -36,9 +36,11 @@ import static com.tencent.supersonic.common.pojo.Constants.AT_SYMBOL; @Component public class SqlUtils { - @Getter private Database database; + @Getter + private Database database; - @Autowired private JdbcDataSource jdbcDataSource; + @Autowired + private JdbcDataSource jdbcDataSource; @Value("${s2.source.result-limit:1000000}") private int resultLimit; @@ -46,9 +48,11 @@ public class SqlUtils { @Value("${s2.source.enable-query-log:false}") private boolean isQueryLogEnable; - @Getter private DataType dataTypeEnum; + @Getter + private DataType dataTypeEnum; - @Getter private JdbcDataSourceUtils jdbcDataSourceUtils; + @Getter + private JdbcDataSourceUtils jdbcDataSourceUtils; public SqlUtils() {} @@ -60,14 +64,10 @@ public class SqlUtils { public SqlUtils init(Database database) { return SqlUtilsBuilder.getBuilder() .withName(database.getId() + AT_SYMBOL + database.getName()) - .withType(database.getType()) - .withJdbcUrl(database.getUrl()) - .withUsername(database.getUsername()) - .withPassword(database.getPassword()) - .withJdbcDataSource(this.jdbcDataSource) - .withResultLimit(this.resultLimit) - .withIsQueryLogEnable(this.isQueryLogEnable) - .build(); + .withType(database.getType()).withJdbcUrl(database.getUrl()) + .withUsername(database.getUsername()).withPassword(database.getPassword()) + .withJdbcDataSource(this.jdbcDataSource).withResultLimit(this.resultLimit) + .withIsQueryLogEnable(this.isQueryLogEnable).build(); } public List> execute(String sql) throws ServerException { @@ -105,27 +105,25 @@ public class SqlUtils { getResult(sql, queryResultWithColumns, jdbcTemplate()); } - private SemanticQueryResp getResult( - String sql, SemanticQueryResp queryResultWithColumns, JdbcTemplate jdbcTemplate) { - jdbcTemplate.query( - sql, - rs -> { - if (null == rs) { - return queryResultWithColumns; - } + private SemanticQueryResp getResult(String sql, SemanticQueryResp queryResultWithColumns, + JdbcTemplate jdbcTemplate) { + jdbcTemplate.query(sql, rs -> { + if (null == rs) { + return queryResultWithColumns; + } - ResultSetMetaData metaData = rs.getMetaData(); - List queryColumns = new ArrayList<>(); - for (int i = 1; i <= metaData.getColumnCount(); i++) { - String key = metaData.getColumnLabel(i); - queryColumns.add(new QueryColumn(key, metaData.getColumnTypeName(i))); - } - queryResultWithColumns.setColumns(queryColumns); + ResultSetMetaData metaData = rs.getMetaData(); + List queryColumns = new ArrayList<>(); + for (int i = 1; i <= metaData.getColumnCount(); i++) { + String key = metaData.getColumnLabel(i); + queryColumns.add(new QueryColumn(key, metaData.getColumnTypeName(i))); + } + queryResultWithColumns.setColumns(queryColumns); - List> resultList = getAllData(rs, queryColumns); - queryResultWithColumns.setResultList(resultList); - return queryResultWithColumns; - }); + List> resultList = getAllData(rs, queryColumns); + queryResultWithColumns.setResultList(resultList); + return queryResultWithColumns; + }); return queryResultWithColumns; } @@ -226,14 +224,8 @@ public class SqlUtils { } public SqlUtils build() { - Database database = - Database.builder() - .name(this.name) - .type(this.type) - .url(this.jdbcUrl) - .username(this.username) - .password(this.password) - .build(); + Database database = Database.builder().name(this.name).type(this.type).url(this.jdbcUrl) + .username(this.username).password(this.password).build(); SqlUtils sqlUtils = new SqlUtils(database); sqlUtils.jdbcDataSource = this.jdbcDataSource; diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlVariableParseUtils.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlVariableParseUtils.java index 6a3d0ae78..7e3920ad0 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlVariableParseUtils.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SqlVariableParseUtils.java @@ -38,38 +38,32 @@ public class SqlVariableParseUtils { return sql; } // 1. handle default variable value - sqlVariables.forEach( - variable -> { - variables.put( - variable.getName().trim(), - getValues(variable.getValueType(), variable.getDefaultValues())); - }); + sqlVariables.forEach(variable -> { + variables.put(variable.getName().trim(), + getValues(variable.getValueType(), variable.getDefaultValues())); + }); // override by variable param if (!CollectionUtils.isEmpty(params)) { Map> map = sqlVariables.stream().collect(Collectors.groupingBy(SqlVariable::getName)); - params.forEach( - p -> { - if (map.containsKey(p.getName())) { - List list = map.get(p.getName()); - if (!CollectionUtils.isEmpty(list)) { - SqlVariable v = list.get(list.size() - 1); - variables.put( - p.getName().trim(), - getValue(v.getValueType(), p.getValue())); - } - } - }); + params.forEach(p -> { + if (map.containsKey(p.getName())) { + List list = map.get(p.getName()); + if (!CollectionUtils.isEmpty(list)) { + SqlVariable v = list.get(list.size() - 1); + variables.put(p.getName().trim(), getValue(v.getValueType(), p.getValue())); + } + } + }); } - variables.forEach( - (k, v) -> { - if (v instanceof List && ((List) v).size() > 0) { - v = ((List) v).stream().collect(Collectors.joining(COMMA)).toString(); - } - variables.put(k, v); - }); + variables.forEach((k, v) -> { + if (v instanceof List && ((List) v).size() > 0) { + v = ((List) v).stream().collect(Collectors.joining(COMMA)).toString(); + } + variables.put(k, v); + }); return parse(sql, variables); } @@ -88,17 +82,12 @@ public class SqlVariableParseUtils { if (null != valueType) { switch (valueType) { case STRING: - return values.stream() - .map(String::valueOf) - .map( - s -> - s.startsWith(APOSTROPHE) && s.endsWith(APOSTROPHE) - ? s - : String.join(EMPTY, APOSTROPHE, s, APOSTROPHE)) + return values.stream().map(String::valueOf) + .map(s -> s.startsWith(APOSTROPHE) && s.endsWith(APOSTROPHE) ? s + : String.join(EMPTY, APOSTROPHE, s, APOSTROPHE)) .collect(Collectors.toList()); case EXPR: - values.stream() - .map(String::valueOf) + values.stream().map(String::valueOf) .forEach(SqlVariableParseUtils::checkSensitiveSql); return values.stream().map(String::valueOf).collect(Collectors.toList()); case NUMBER: @@ -115,11 +104,8 @@ public class SqlVariableParseUtils { if (null != valueType) { switch (valueType) { case STRING: - return String.join( - EMPTY, - value.startsWith(APOSTROPHE) ? EMPTY : APOSTROPHE, - value, - value.endsWith(APOSTROPHE) ? EMPTY : APOSTROPHE); + return String.join(EMPTY, value.startsWith(APOSTROPHE) ? EMPTY : APOSTROPHE, + value, value.endsWith(APOSTROPHE) ? EMPTY : APOSTROPHE); case NUMBER: case EXPR: default: diff --git a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SysTimeDimensionBuilder.java b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SysTimeDimensionBuilder.java index 7dfece732..3bd63e87c 100644 --- a/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SysTimeDimensionBuilder.java +++ b/headless/core/src/main/java/com/tencent/supersonic/headless/core/utils/SysTimeDimensionBuilder.java @@ -17,8 +17,7 @@ public class SysTimeDimensionBuilder { // Defines the regular expression pattern for the time keyword private static final Pattern TIME_KEYWORD_PATTERN = - Pattern.compile( - "\\b(DATE|TIME|TIMESTAMP|YEAR|MONTH|DAY|HOUR|MINUTE|SECOND)\\b", + Pattern.compile("\\b(DATE|TIME|TIMESTAMP|YEAR|MONTH|DAY|HOUR|MINUTE|SECOND)\\b", Pattern.CASE_INSENSITIVE); public static void addSysTimeDimension(List dims, DbAdaptor engineAdaptor) { @@ -39,9 +38,8 @@ public class SysTimeDimensionBuilder { Dim dim = new Dim(); dim.setBizName(TimeDimensionEnum.DAY.getName()); dim.setType(DimensionType.partition_time.name()); - dim.setExpr( - generateTimeExpr( - timeDim, TimeDimensionEnum.DAY.name().toLowerCase(), engineAdaptor)); + dim.setExpr(generateTimeExpr(timeDim, TimeDimensionEnum.DAY.name().toLowerCase(), + engineAdaptor)); DimensionTimeTypeParams typeParams = new DimensionTimeTypeParams(); typeParams.setTimeGranularity(TimeDimensionEnum.DAY.name().toLowerCase()); typeParams.setIsPrimary("true"); @@ -53,9 +51,8 @@ public class SysTimeDimensionBuilder { Dim dim = new Dim(); dim.setBizName(TimeDimensionEnum.WEEK.getName()); dim.setType(DimensionType.partition_time.name()); - dim.setExpr( - generateTimeExpr( - timeDim, TimeDimensionEnum.WEEK.name().toLowerCase(), engineAdaptor)); + dim.setExpr(generateTimeExpr(timeDim, TimeDimensionEnum.WEEK.name().toLowerCase(), + engineAdaptor)); DimensionTimeTypeParams typeParams = new DimensionTimeTypeParams(); typeParams.setTimeGranularity(TimeDimensionEnum.WEEK.name().toLowerCase()); typeParams.setIsPrimary("false"); @@ -67,9 +64,8 @@ public class SysTimeDimensionBuilder { Dim dim = new Dim(); dim.setBizName(TimeDimensionEnum.MONTH.getName()); dim.setType(DimensionType.partition_time.name()); - dim.setExpr( - generateTimeExpr( - timeDim, TimeDimensionEnum.MONTH.name().toLowerCase(), engineAdaptor)); + dim.setExpr(generateTimeExpr(timeDim, TimeDimensionEnum.MONTH.name().toLowerCase(), + engineAdaptor)); DimensionTimeTypeParams typeParams = new DimensionTimeTypeParams(); typeParams.setTimeGranularity(TimeDimensionEnum.MONTH.name().toLowerCase()); typeParams.setIsPrimary("false"); diff --git a/headless/core/src/test/java/com/tencent/supersonic/chat/core/parser/aggregate/CalciteSqlParserTest.java b/headless/core/src/test/java/com/tencent/supersonic/chat/core/parser/aggregate/CalciteSqlParserTest.java index 30ee19cf7..5a3124bf8 100644 --- a/headless/core/src/test/java/com/tencent/supersonic/chat/core/parser/aggregate/CalciteSqlParserTest.java +++ b/headless/core/src/test/java/com/tencent/supersonic/chat/core/parser/aggregate/CalciteSqlParserTest.java @@ -11,462 +11,316 @@ public class CalciteSqlParserTest { @Test public void testCalciteSqlParser() throws Exception { - String json = - "{\n" - + " \"dataSetId\": 1,\n" - + " \"sql\": \"\",\n" - + " \"sourceId\": \"\",\n" - + " \"errMsg\": \"\",\n" - + " \"metricQueryParam\": {\n" - + " \"metrics\": [\n" - + " \"pv\"\n" - + " ],\n" - + " \"dimensions\": [\n" - + " \"sys_imp_date\"\n" - + " ],\n" - + " \"nativeQuery\": false\n" - + " },\n" - + " \"status\": 0,\n" - + " \"isS2SQL\": false,\n" - + " \"enableOptimize\": true,\n" - + " \"minMaxTime\": {\n" - + " \"left\": \"sys_imp_date\",\n" - + " \"middle\": \"2024-03-24\",\n" - + " \"right\": \"2024-03-18\"\n" - + " },\n" - + " \"dataSetSql\": \"SELECT sys_imp_date, SUM(pv) AS pv FROM t_1 WHERE " - + "sys_imp_date >= '2024-03-18' AND sys_imp_date <= '2024-03-24' GROUP BY sys_imp_date LIMIT 365\",\n" - + " \"dataSetAlias\": \"t_1\",\n" - + " \"dataSetSimplifySql\": \"\",\n" - + " \"enableLimitWrapper\": false,\n" - + " \"semanticModel\": {\n" - + " \"schemaKey\": \"VIEW_1\",\n" - + " \"metrics\": [\n" - + " {\n" - + " \"name\": \"pv\",\n" - + " \"owners\": [\n" - + " \"admin\"\n" - + " ],\n" - + " \"type\": \"ATOMIC\",\n" - + " \"metricTypeParams\": {\n" - + " \"measures\": [\n" - + " {\n" - + " \"name\": \"s2_pv_uv_statis_pv\",\n" - + " \"agg\": \"SUM\",\n" - + " \"constraint\": \"\"\n" - + " }\n" - + " ],\n" - + " \"isFieldMetric\": false,\n" - + " \"expr\": \"s2_pv_uv_statis_pv\"\n" - + " }\n" - + " },\n" - + " {\n" - + " \"name\": \"uv\",\n" - + " \"owners\": [\n" - + " \"admin\"\n" - + " ],\n" - + " \"type\": \"DERIVED\",\n" - + " \"metricTypeParams\": {\n" - + " \"measures\": [\n" - + " {\n" - + " \"name\": \"user_id\",\n" - + " \"expr\": \"user_id\"\n" - + " }\n" - + " ],\n" - + " \"isFieldMetric\": true,\n" - + " \"expr\": \"user_id\"\n" - + " }\n" - + " },\n" - + " {\n" - + " \"name\": \"pv_avg\",\n" - + " \"owners\": [\n" - + " \"admin\"\n" - + " ],\n" - + " \"type\": \"DERIVED\",\n" - + " \"metricTypeParams\": {\n" - + " \"measures\": [\n" - + " {\n" - + " \"name\": \"pv\",\n" - + " \"expr\": \"pv\"\n" - + " },\n" - + " {\n" - + " \"name\": \"uv\",\n" - + " \"expr\": \"uv\"\n" - + " }\n" - + " ],\n" - + " \"isFieldMetric\": true,\n" - + " \"expr\": \"pv\"\n" - + " }\n" - + " },\n" - + " {\n" - + " \"name\": \"stay_hours\",\n" - + " \"owners\": [\n" - + " \"admin\"\n" - + " ],\n" - + " \"type\": \"ATOMIC\",\n" - + " \"metricTypeParams\": {\n" - + " \"measures\": [\n" - + " {\n" - + " \"name\": \"s2_stay_time_statis_stay_hours\",\n" - + " \"agg\": \"SUM\",\n" - + " \"constraint\": \"\"\n" - + " }\n" - + " ],\n" - + " \"isFieldMetric\": false,\n" - + " \"expr\": \"s2_stay_time_statis_stay_hours\"\n" - + " }\n" - + " }\n" - + " ],\n" - + " \"datasourceMap\": {\n" - + " \"user_department\": {\n" - + " \"id\": 1,\n" - + " \"name\": \"user_department\",\n" - + " \"sourceId\": 1,\n" - + " \"type\": \"h2\",\n" - + " \"sqlQuery\": \"select user_name,department from s2_user_department\",\n" - + " \"identifiers\": [\n" - + " {\n" - + " \"name\": \"user_name\",\n" - + " \"type\": \"primary\"\n" - + " }\n" - + " ],\n" - + " \"dimensions\": [\n" - + " {\n" - + " \"name\": \"department\",\n" - + " \"type\": \"categorical\",\n" - + " \"expr\": \"department\",\n" - + " \"dimensionTimeTypeParams\": {\n" - + " },\n" - + " \"dataType\": \"UNKNOWN\",\n" - + " \"bizName\": \"department\"\n" - + " }\n" - + " ],\n" - + " \"measures\": [\n" - + " {\n" - + " \"name\": \"user_department_internal_cnt\",\n" - + " \"agg\": \"count\",\n" - + " \"expr\": \"user_name\"\n" - + " },\n" - + " {\n" - + " \"name\": \"user_name\",\n" - + " \"agg\": \"\",\n" - + " \"expr\": \"user_name\"\n" - + " },\n" - + " {\n" - + " \"name\": \"department\",\n" - + " \"agg\": \"\",\n" - + " \"expr\": \"department\"\n" - + " }\n" - + " ],\n" - + " \"aggTime\": \"none\"\n" - + " },\n" - + " \"s2_pv_uv_statis\": {\n" - + " \"id\": 2,\n" - + " \"name\": \"s2_pv_uv_statis\",\n" - + " \"sourceId\": 1,\n" - + " \"type\": \"h2\",\n" - + " \"sqlQuery\": \"SELECT imp_date, user_name, page, 1 as pv, user_name as user_id " - + "FROM s2_pv_uv_statis\",\n" - + " \"identifiers\": [\n" - + " {\n" - + " \"name\": \"user_name\",\n" - + " \"type\": \"primary\"\n" - + " }\n" - + " ],\n" - + " \"dimensions\": [\n" - + " {\n" - + " \"name\": \"imp_date\",\n" - + " \"type\": \"time\",\n" - + " \"expr\": \"imp_date\",\n" - + " \"dimensionTimeTypeParams\": {\n" - + " \"isPrimary\": \"true\",\n" - + " \"timeGranularity\": \"day\"\n" - + " },\n" - + " \"dataType\": \"UNKNOWN\",\n" - + " \"bizName\": \"imp_date\"\n" - + " },\n" - + " {\n" - + " \"name\": \"page\",\n" - + " \"type\": \"categorical\",\n" - + " \"expr\": \"page\",\n" - + " \"dimensionTimeTypeParams\": {\n" - + " },\n" - + " \"dataType\": \"UNKNOWN\",\n" - + " \"bizName\": \"page\"\n" - + " },\n" - + " {\n" - + " \"name\": \"sys_imp_date\",\n" - + " \"type\": \"time\",\n" - + " \"expr\": \"imp_date\",\n" - + " \"dimensionTimeTypeParams\": {\n" - + " \"isPrimary\": \"true\",\n" - + " \"timeGranularity\": \"day\"\n" - + " },\n" - + " \"dataType\": \"UNKNOWN\",\n" - + " \"bizName\": \"sys_imp_date\"\n" - + " },\n" - + " {\n" - + " \"name\": \"sys_imp_week\",\n" - + " \"type\": \"time\",\n" - + " \"expr\": \"DATE_TRUNC('week',imp_date)\",\n" - + " \"dimensionTimeTypeParams\": {\n" - + " \"isPrimary\": \"false\",\n" - + " \"timeGranularity\": \"week\"\n" - + " },\n" - + " \"dataType\": \"UNKNOWN\",\n" - + " \"bizName\": \"sys_imp_week\"\n" - + " },\n" - + " {\n" - + " \"name\": \"sys_imp_month\",\n" - + " \"type\": \"time\",\n" - + " \"expr\": \"FORMATDATETIME(PARSEDATETIME" - + "(imp_date, 'yyyy-MM-dd'),'yyyy-MM') \",\n" - + " \"dimensionTimeTypeParams\": {\n" - + " \"isPrimary\": \"false\",\n" - + " \"timeGranularity\": \"month\"\n" - + " },\n" - + " \"dataType\": \"UNKNOWN\",\n" - + " \"bizName\": \"sys_imp_month\"\n" - + " }\n" - + " ],\n" - + " \"measures\": [\n" - + " {\n" - + " \"name\": \"s2_pv_uv_statis_pv\",\n" - + " \"agg\": \"SUM\",\n" - + " \"expr\": \"pv\"\n" - + " },\n" - + " {\n" - + " \"name\": \"s2_pv_uv_statis_user_id\",\n" - + " \"agg\": \"SUM\",\n" - + " \"expr\": \"user_id\"\n" - + " },\n" - + " {\n" - + " \"name\": \"s2_pv_uv_statis_internal_cnt\",\n" - + " \"agg\": \"count\",\n" - + " \"expr\": \"user_name\"\n" - + " },\n" - + " {\n" - + " \"name\": \"user_name\",\n" - + " \"agg\": \"\",\n" - + " \"expr\": \"user_name\"\n" - + " },\n" - + " {\n" - + " \"name\": \"imp_date\",\n" - + " \"agg\": \"\",\n" - + " \"expr\": \"imp_date\"\n" - + " },\n" - + " {\n" - + " \"name\": \"page\",\n" - + " \"agg\": \"\",\n" - + " \"expr\": \"page\"\n" - + " },\n" - + " {\n" - + " \"name\": \"pv\",\n" - + " \"agg\": \"\",\n" - + " \"expr\": \"pv\"\n" - + " },\n" - + " {\n" - + " \"name\": \"user_id\",\n" - + " \"agg\": \"\",\n" - + " \"expr\": \"user_id\"\n" - + " }\n" - + " ],\n" - + " \"aggTime\": \"day\"\n" - + " },\n" - + " \"s2_stay_time_statis\": {\n" - + " \"id\": 3,\n" - + " \"name\": \"s2_stay_time_statis\",\n" - + " \"sourceId\": 1,\n" - + " \"type\": \"h2\",\n" - + " \"sqlQuery\": \"select imp_date,user_name,stay_hours" - + ",page from s2_stay_time_statis\",\n" - + " \"identifiers\": [\n" - + " {\n" - + " \"name\": \"user_name\",\n" - + " \"type\": \"primary\"\n" - + " }\n" - + " ],\n" - + " \"dimensions\": [\n" - + " {\n" - + " \"name\": \"imp_date\",\n" - + " \"type\": \"time\",\n" - + " \"expr\": \"imp_date\",\n" - + " \"dimensionTimeTypeParams\": {\n" - + " \"isPrimary\": \"true\",\n" - + " \"timeGranularity\": \"day\"\n" - + " },\n" - + " \"dataType\": \"UNKNOWN\",\n" - + " \"bizName\": \"imp_date\"\n" - + " },\n" - + " {\n" - + " \"name\": \"page\",\n" - + " \"type\": \"categorical\",\n" - + " \"expr\": \"page\",\n" - + " \"dimensionTimeTypeParams\": {\n" - + " },\n" - + " \"dataType\": \"UNKNOWN\",\n" - + " \"bizName\": \"page\"\n" - + " },\n" - + " {\n" - + " \"name\": \"sys_imp_date\",\n" - + " \"type\": \"time\",\n" - + " \"expr\": \"imp_date\",\n" - + " \"dimensionTimeTypeParams\": {\n" - + " \"isPrimary\": \"true\",\n" - + " \"timeGranularity\": \"day\"\n" - + " },\n" - + " \"dataType\": \"UNKNOWN\",\n" - + " \"bizName\": \"sys_imp_date\"\n" - + " },\n" - + " {\n" - + " \"name\": \"sys_imp_week\",\n" - + " \"type\": \"time\",\n" - + " \"expr\": \"DATE_TRUNC('week',imp_date)\",\n" - + " \"dimensionTimeTypeParams\": {\n" - + " \"isPrimary\": \"false\",\n" - + " \"timeGranularity\": \"week\"\n" - + " },\n" - + " \"dataType\": \"UNKNOWN\",\n" - + " \"bizName\": \"sys_imp_week\"\n" - + " },\n" - + " {\n" - + " \"name\": \"sys_imp_month\",\n" - + " \"type\": \"time\",\n" - + " \"expr\": \"FORMATDATETIME(PARSEDATETIME" - + "(imp_date, 'yyyy-MM-dd'),'yyyy-MM') \",\n" - + " \"dimensionTimeTypeParams\": {\n" - + " \"isPrimary\": \"false\",\n" - + " \"timeGranularity\": \"month\"\n" - + " },\n" - + " \"dataType\": \"UNKNOWN\",\n" - + " \"bizName\": \"sys_imp_month\"\n" - + " }\n" - + " ],\n" - + " \"measures\": [\n" - + " {\n" - + " \"name\": \"s2_stay_time_statis_stay_hours\",\n" - + " \"agg\": \"SUM\",\n" - + " \"expr\": \"stay_hours\"\n" - + " },\n" - + " {\n" - + " \"name\": \"s2_stay_time_statis_internal_cnt\",\n" - + " \"agg\": \"count\",\n" - + " \"expr\": \"user_name\"\n" - + " },\n" - + " {\n" - + " \"name\": \"user_name\",\n" - + " \"agg\": \"\",\n" - + " \"expr\": \"user_name\"\n" - + " },\n" - + " {\n" - + " \"name\": \"imp_date\",\n" - + " \"agg\": \"\",\n" - + " \"expr\": \"imp_date\"\n" - + " },\n" - + " {\n" - + " \"name\": \"page\",\n" - + " \"agg\": \"\",\n" - + " \"expr\": \"page\"\n" - + " },\n" - + " {\n" - + " \"name\": \"stay_hours\",\n" - + " \"agg\": \"\",\n" - + " \"expr\": \"stay_hours\"\n" - + " }\n" - + " ],\n" - + " \"aggTime\": \"day\"\n" - + " }\n" - + " },\n" - + " \"dimensionMap\": {\n" - + " \"user_department\": [\n" - + " {\n" - + " \"name\": \"department\",\n" - + " \"owners\": \"admin\",\n" - + " \"type\": \"categorical\",\n" - + " \"expr\": \"department\",\n" - + " \"dimensionTimeTypeParams\": {\n" - + " },\n" - + " \"dataType\": \"UNKNOWN\",\n" - + " \"bizName\": \"department\"\n" - + " }\n" - + " ],\n" - + " \"s2_pv_uv_statis\": [\n" - + " ],\n" - + " \"s2_stay_time_statis\": [\n" - + " {\n" - + " \"name\": \"page\",\n" - + " \"owners\": \"admin\",\n" - + " \"type\": \"categorical\",\n" - + " \"expr\": \"page\",\n" - + " \"dimensionTimeTypeParams\": {\n" - + " },\n" - + " \"dataType\": \"UNKNOWN\",\n" - + " \"bizName\": \"page\"\n" - + " }\n" - + " ]\n" - + " },\n" - + " \"materializationList\": [\n" - + " ],\n" - + " \"joinRelations\": [\n" - + " {\n" - + " \"id\": 1,\n" - + " \"left\": \"user_department\",\n" - + " \"right\": \"s2_pv_uv_statis\",\n" - + " \"joinType\": \"left join\",\n" - + " \"joinCondition\": [\n" - + " {\n" - + " \"left\": \"user_name\",\n" - + " \"middle\": \"=\",\n" - + " \"right\": \"user_name\"\n" - + " }\n" - + " ]\n" - + " },\n" - + " {\n" - + " \"id\": 2,\n" - + " \"left\": \"user_department\",\n" - + " \"right\": \"s2_stay_time_statis\",\n" - + " \"joinType\": \"left join\",\n" - + " \"joinCondition\": [\n" - + " {\n" - + " \"left\": \"user_name\",\n" - + " \"middle\": \"=\",\n" - + " \"right\": \"user_name\"\n" - + " }\n" - + " ]\n" - + " }\n" - + " ],\n" - + " \"database\": {\n" - + " \"id\": 1,\n" - + " \"name\": \"数据实例\",\n" - + " \"description\": \"样例数据库实例\",\n" - + " \"url\": \"jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false\",\n" - + " \"username\": \"root\",\n" - + " \"password\": \"semantic\",\n" - + " \"type\": \"h2\",\n" - + " \"connectInfo\": {\n" - + " \"url\": \"jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false\",\n" - + " \"userName\": \"root\",\n" - + " \"password\": \"semantic\"\n" - + " },\n" - + " \"admins\": [\n" - + " ],\n" - + " \"viewers\": [\n" - + " ],\n" - + " \"createdBy\": \"admin\",\n" - + " \"updatedBy\": \"admin\",\n" - + " \"createdAt\": 1711367511146,\n" - + " \"updatedAt\": 1711367511146\n" - + " }\n" - + " }\n" - + "}"; + String json = "{\n" + " \"dataSetId\": 1,\n" + " \"sql\": \"\",\n" + + " \"sourceId\": \"\",\n" + " \"errMsg\": \"\",\n" + + " \"metricQueryParam\": {\n" + " \"metrics\": [\n" + + " \"pv\"\n" + " ],\n" + " \"dimensions\": [\n" + + " \"sys_imp_date\"\n" + " ],\n" + + " \"nativeQuery\": false\n" + " },\n" + " \"status\": 0,\n" + + " \"isS2SQL\": false,\n" + " \"enableOptimize\": true,\n" + + " \"minMaxTime\": {\n" + " \"left\": \"sys_imp_date\",\n" + + " \"middle\": \"2024-03-24\",\n" + " \"right\": \"2024-03-18\"\n" + + " },\n" + + " \"dataSetSql\": \"SELECT sys_imp_date, SUM(pv) AS pv FROM t_1 WHERE " + + "sys_imp_date >= '2024-03-18' AND sys_imp_date <= '2024-03-24' GROUP BY sys_imp_date LIMIT 365\",\n" + + " \"dataSetAlias\": \"t_1\",\n" + " \"dataSetSimplifySql\": \"\",\n" + + " \"enableLimitWrapper\": false,\n" + " \"semanticModel\": {\n" + + " \"schemaKey\": \"VIEW_1\",\n" + " \"metrics\": [\n" + + " {\n" + " \"name\": \"pv\",\n" + + " \"owners\": [\n" + " \"admin\"\n" + + " ],\n" + " \"type\": \"ATOMIC\",\n" + + " \"metricTypeParams\": {\n" + + " \"measures\": [\n" + " {\n" + + " \"name\": \"s2_pv_uv_statis_pv\",\n" + + " \"agg\": \"SUM\",\n" + + " \"constraint\": \"\"\n" + + " }\n" + " ],\n" + + " \"isFieldMetric\": false,\n" + + " \"expr\": \"s2_pv_uv_statis_pv\"\n" + " }\n" + + " },\n" + " {\n" + " \"name\": \"uv\",\n" + + " \"owners\": [\n" + " \"admin\"\n" + + " ],\n" + " \"type\": \"DERIVED\",\n" + + " \"metricTypeParams\": {\n" + + " \"measures\": [\n" + " {\n" + + " \"name\": \"user_id\",\n" + + " \"expr\": \"user_id\"\n" + + " }\n" + " ],\n" + + " \"isFieldMetric\": true,\n" + + " \"expr\": \"user_id\"\n" + " }\n" + + " },\n" + " {\n" + " \"name\": \"pv_avg\",\n" + + " \"owners\": [\n" + " \"admin\"\n" + + " ],\n" + " \"type\": \"DERIVED\",\n" + + " \"metricTypeParams\": {\n" + + " \"measures\": [\n" + " {\n" + + " \"name\": \"pv\",\n" + + " \"expr\": \"pv\"\n" + " },\n" + + " {\n" + " \"name\": \"uv\",\n" + + " \"expr\": \"uv\"\n" + " }\n" + + " ],\n" + " \"isFieldMetric\": true,\n" + + " \"expr\": \"pv\"\n" + " }\n" + + " },\n" + " {\n" + + " \"name\": \"stay_hours\",\n" + " \"owners\": [\n" + + " \"admin\"\n" + " ],\n" + + " \"type\": \"ATOMIC\",\n" + + " \"metricTypeParams\": {\n" + + " \"measures\": [\n" + " {\n" + + " \"name\": \"s2_stay_time_statis_stay_hours\",\n" + + " \"agg\": \"SUM\",\n" + + " \"constraint\": \"\"\n" + + " }\n" + " ],\n" + + " \"isFieldMetric\": false,\n" + + " \"expr\": \"s2_stay_time_statis_stay_hours\"\n" + + " }\n" + " }\n" + " ],\n" + + " \"datasourceMap\": {\n" + " \"user_department\": {\n" + + " \"id\": 1,\n" + + " \"name\": \"user_department\",\n" + + " \"sourceId\": 1,\n" + " \"type\": \"h2\",\n" + + " \"sqlQuery\": \"select user_name,department from s2_user_department\",\n" + + " \"identifiers\": [\n" + " {\n" + + " \"name\": \"user_name\",\n" + + " \"type\": \"primary\"\n" + " }\n" + + " ],\n" + " \"dimensions\": [\n" + + " {\n" + " \"name\": \"department\",\n" + + " \"type\": \"categorical\",\n" + + " \"expr\": \"department\",\n" + + " \"dimensionTimeTypeParams\": {\n" + + " },\n" + + " \"dataType\": \"UNKNOWN\",\n" + + " \"bizName\": \"department\"\n" + + " }\n" + " ],\n" + + " \"measures\": [\n" + " {\n" + + " \"name\": \"user_department_internal_cnt\",\n" + + " \"agg\": \"count\",\n" + + " \"expr\": \"user_name\"\n" + " },\n" + + " {\n" + " \"name\": \"user_name\",\n" + + " \"agg\": \"\",\n" + + " \"expr\": \"user_name\"\n" + " },\n" + + " {\n" + " \"name\": \"department\",\n" + + " \"agg\": \"\",\n" + + " \"expr\": \"department\"\n" + " }\n" + + " ],\n" + " \"aggTime\": \"none\"\n" + + " },\n" + " \"s2_pv_uv_statis\": {\n" + + " \"id\": 2,\n" + + " \"name\": \"s2_pv_uv_statis\",\n" + + " \"sourceId\": 1,\n" + " \"type\": \"h2\",\n" + + " \"sqlQuery\": \"SELECT imp_date, user_name, page, 1 as pv, user_name as user_id " + + "FROM s2_pv_uv_statis\",\n" + " \"identifiers\": [\n" + + " {\n" + " \"name\": \"user_name\",\n" + + " \"type\": \"primary\"\n" + " }\n" + + " ],\n" + " \"dimensions\": [\n" + + " {\n" + " \"name\": \"imp_date\",\n" + + " \"type\": \"time\",\n" + + " \"expr\": \"imp_date\",\n" + + " \"dimensionTimeTypeParams\": {\n" + + " \"isPrimary\": \"true\",\n" + + " \"timeGranularity\": \"day\"\n" + + " },\n" + + " \"dataType\": \"UNKNOWN\",\n" + + " \"bizName\": \"imp_date\"\n" + " },\n" + + " {\n" + " \"name\": \"page\",\n" + + " \"type\": \"categorical\",\n" + + " \"expr\": \"page\",\n" + + " \"dimensionTimeTypeParams\": {\n" + + " },\n" + + " \"dataType\": \"UNKNOWN\",\n" + + " \"bizName\": \"page\"\n" + " },\n" + + " {\n" + + " \"name\": \"sys_imp_date\",\n" + + " \"type\": \"time\",\n" + + " \"expr\": \"imp_date\",\n" + + " \"dimensionTimeTypeParams\": {\n" + + " \"isPrimary\": \"true\",\n" + + " \"timeGranularity\": \"day\"\n" + + " },\n" + + " \"dataType\": \"UNKNOWN\",\n" + + " \"bizName\": \"sys_imp_date\"\n" + + " },\n" + " {\n" + + " \"name\": \"sys_imp_week\",\n" + + " \"type\": \"time\",\n" + + " \"expr\": \"DATE_TRUNC('week',imp_date)\",\n" + + " \"dimensionTimeTypeParams\": {\n" + + " \"isPrimary\": \"false\",\n" + + " \"timeGranularity\": \"week\"\n" + + " },\n" + + " \"dataType\": \"UNKNOWN\",\n" + + " \"bizName\": \"sys_imp_week\"\n" + + " },\n" + " {\n" + + " \"name\": \"sys_imp_month\",\n" + + " \"type\": \"time\",\n" + + " \"expr\": \"FORMATDATETIME(PARSEDATETIME" + + "(imp_date, 'yyyy-MM-dd'),'yyyy-MM') \",\n" + + " \"dimensionTimeTypeParams\": {\n" + + " \"isPrimary\": \"false\",\n" + + " \"timeGranularity\": \"month\"\n" + + " },\n" + + " \"dataType\": \"UNKNOWN\",\n" + + " \"bizName\": \"sys_imp_month\"\n" + + " }\n" + " ],\n" + + " \"measures\": [\n" + " {\n" + + " \"name\": \"s2_pv_uv_statis_pv\",\n" + + " \"agg\": \"SUM\",\n" + + " \"expr\": \"pv\"\n" + " },\n" + + " {\n" + + " \"name\": \"s2_pv_uv_statis_user_id\",\n" + + " \"agg\": \"SUM\",\n" + + " \"expr\": \"user_id\"\n" + " },\n" + + " {\n" + + " \"name\": \"s2_pv_uv_statis_internal_cnt\",\n" + + " \"agg\": \"count\",\n" + + " \"expr\": \"user_name\"\n" + " },\n" + + " {\n" + " \"name\": \"user_name\",\n" + + " \"agg\": \"\",\n" + + " \"expr\": \"user_name\"\n" + " },\n" + + " {\n" + " \"name\": \"imp_date\",\n" + + " \"agg\": \"\",\n" + + " \"expr\": \"imp_date\"\n" + " },\n" + + " {\n" + " \"name\": \"page\",\n" + + " \"agg\": \"\",\n" + + " \"expr\": \"page\"\n" + " },\n" + + " {\n" + " \"name\": \"pv\",\n" + + " \"agg\": \"\",\n" + + " \"expr\": \"pv\"\n" + " },\n" + + " {\n" + " \"name\": \"user_id\",\n" + + " \"agg\": \"\",\n" + + " \"expr\": \"user_id\"\n" + " }\n" + + " ],\n" + " \"aggTime\": \"day\"\n" + + " },\n" + " \"s2_stay_time_statis\": {\n" + + " \"id\": 3,\n" + + " \"name\": \"s2_stay_time_statis\",\n" + + " \"sourceId\": 1,\n" + " \"type\": \"h2\",\n" + + " \"sqlQuery\": \"select imp_date,user_name,stay_hours" + + ",page from s2_stay_time_statis\",\n" + " \"identifiers\": [\n" + + " {\n" + " \"name\": \"user_name\",\n" + + " \"type\": \"primary\"\n" + " }\n" + + " ],\n" + " \"dimensions\": [\n" + + " {\n" + " \"name\": \"imp_date\",\n" + + " \"type\": \"time\",\n" + + " \"expr\": \"imp_date\",\n" + + " \"dimensionTimeTypeParams\": {\n" + + " \"isPrimary\": \"true\",\n" + + " \"timeGranularity\": \"day\"\n" + + " },\n" + + " \"dataType\": \"UNKNOWN\",\n" + + " \"bizName\": \"imp_date\"\n" + " },\n" + + " {\n" + " \"name\": \"page\",\n" + + " \"type\": \"categorical\",\n" + + " \"expr\": \"page\",\n" + + " \"dimensionTimeTypeParams\": {\n" + + " },\n" + + " \"dataType\": \"UNKNOWN\",\n" + + " \"bizName\": \"page\"\n" + " },\n" + + " {\n" + + " \"name\": \"sys_imp_date\",\n" + + " \"type\": \"time\",\n" + + " \"expr\": \"imp_date\",\n" + + " \"dimensionTimeTypeParams\": {\n" + + " \"isPrimary\": \"true\",\n" + + " \"timeGranularity\": \"day\"\n" + + " },\n" + + " \"dataType\": \"UNKNOWN\",\n" + + " \"bizName\": \"sys_imp_date\"\n" + + " },\n" + " {\n" + + " \"name\": \"sys_imp_week\",\n" + + " \"type\": \"time\",\n" + + " \"expr\": \"DATE_TRUNC('week',imp_date)\",\n" + + " \"dimensionTimeTypeParams\": {\n" + + " \"isPrimary\": \"false\",\n" + + " \"timeGranularity\": \"week\"\n" + + " },\n" + + " \"dataType\": \"UNKNOWN\",\n" + + " \"bizName\": \"sys_imp_week\"\n" + + " },\n" + " {\n" + + " \"name\": \"sys_imp_month\",\n" + + " \"type\": \"time\",\n" + + " \"expr\": \"FORMATDATETIME(PARSEDATETIME" + + "(imp_date, 'yyyy-MM-dd'),'yyyy-MM') \",\n" + + " \"dimensionTimeTypeParams\": {\n" + + " \"isPrimary\": \"false\",\n" + + " \"timeGranularity\": \"month\"\n" + + " },\n" + + " \"dataType\": \"UNKNOWN\",\n" + + " \"bizName\": \"sys_imp_month\"\n" + + " }\n" + " ],\n" + + " \"measures\": [\n" + " {\n" + + " \"name\": \"s2_stay_time_statis_stay_hours\",\n" + + " \"agg\": \"SUM\",\n" + + " \"expr\": \"stay_hours\"\n" + " },\n" + + " {\n" + + " \"name\": \"s2_stay_time_statis_internal_cnt\",\n" + + " \"agg\": \"count\",\n" + + " \"expr\": \"user_name\"\n" + " },\n" + + " {\n" + " \"name\": \"user_name\",\n" + + " \"agg\": \"\",\n" + + " \"expr\": \"user_name\"\n" + " },\n" + + " {\n" + " \"name\": \"imp_date\",\n" + + " \"agg\": \"\",\n" + + " \"expr\": \"imp_date\"\n" + " },\n" + + " {\n" + " \"name\": \"page\",\n" + + " \"agg\": \"\",\n" + + " \"expr\": \"page\"\n" + " },\n" + + " {\n" + " \"name\": \"stay_hours\",\n" + + " \"agg\": \"\",\n" + + " \"expr\": \"stay_hours\"\n" + " }\n" + + " ],\n" + " \"aggTime\": \"day\"\n" + + " }\n" + " },\n" + " \"dimensionMap\": {\n" + + " \"user_department\": [\n" + " {\n" + + " \"name\": \"department\",\n" + + " \"owners\": \"admin\",\n" + + " \"type\": \"categorical\",\n" + + " \"expr\": \"department\",\n" + + " \"dimensionTimeTypeParams\": {\n" + + " },\n" + " \"dataType\": \"UNKNOWN\",\n" + + " \"bizName\": \"department\"\n" + " }\n" + + " ],\n" + " \"s2_pv_uv_statis\": [\n" + " ],\n" + + " \"s2_stay_time_statis\": [\n" + " {\n" + + " \"name\": \"page\",\n" + + " \"owners\": \"admin\",\n" + + " \"type\": \"categorical\",\n" + + " \"expr\": \"page\",\n" + + " \"dimensionTimeTypeParams\": {\n" + + " },\n" + " \"dataType\": \"UNKNOWN\",\n" + + " \"bizName\": \"page\"\n" + " }\n" + + " ]\n" + " },\n" + " \"materializationList\": [\n" + + " ],\n" + " \"joinRelations\": [\n" + " {\n" + + " \"id\": 1,\n" + + " \"left\": \"user_department\",\n" + + " \"right\": \"s2_pv_uv_statis\",\n" + + " \"joinType\": \"left join\",\n" + + " \"joinCondition\": [\n" + " {\n" + + " \"left\": \"user_name\",\n" + + " \"middle\": \"=\",\n" + + " \"right\": \"user_name\"\n" + " }\n" + + " ]\n" + " },\n" + " {\n" + + " \"id\": 2,\n" + + " \"left\": \"user_department\",\n" + + " \"right\": \"s2_stay_time_statis\",\n" + + " \"joinType\": \"left join\",\n" + + " \"joinCondition\": [\n" + " {\n" + + " \"left\": \"user_name\",\n" + + " \"middle\": \"=\",\n" + + " \"right\": \"user_name\"\n" + " }\n" + + " ]\n" + " }\n" + " ],\n" + + " \"database\": {\n" + " \"id\": 1,\n" + + " \"name\": \"数据实例\",\n" + + " \"description\": \"样例数据库实例\",\n" + + " \"url\": \"jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false\",\n" + + " \"username\": \"root\",\n" + + " \"password\": \"semantic\",\n" + " \"type\": \"h2\",\n" + + " \"connectInfo\": {\n" + + " \"url\": \"jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false\",\n" + + " \"userName\": \"root\",\n" + + " \"password\": \"semantic\"\n" + " },\n" + + " \"admins\": [\n" + " ],\n" + + " \"viewers\": [\n" + " ],\n" + + " \"createdBy\": \"admin\",\n" + + " \"updatedBy\": \"admin\",\n" + + " \"createdAt\": 1711367511146,\n" + + " \"updatedAt\": 1711367511146\n" + " }\n" + " }\n" + "}"; QueryStatement queryStatement = JSON.parseObject(json, QueryStatement.class); CalciteQueryParser calciteSqlParser = new CalciteQueryParser(); calciteSqlParser.parse(queryStatement, AggOption.DEFAULT); - Assert.assertEquals( - queryStatement.getSql().trim().replaceAll("\\s+", ""), - "SELECT`imp_date`AS`sys_imp_date`,SUM(1)AS`pv`" - + "FROM" - + "`s2_pv_uv_statis`" + Assert.assertEquals(queryStatement.getSql().trim().replaceAll("\\s+", ""), + "SELECT`imp_date`AS`sys_imp_date`,SUM(1)AS`pv`" + "FROM" + "`s2_pv_uv_statis`" + "GROUPBY`imp_date`,`imp_date`"); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/annotation/ApiHeaderCheck.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/annotation/ApiHeaderCheck.java index 195b6d813..4bb9ab58a 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/annotation/ApiHeaderCheck.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/annotation/ApiHeaderCheck.java @@ -7,4 +7,5 @@ import java.lang.annotation.Target; @Target({ElementType.PARAMETER, ElementType.METHOD}) @Retention(RetentionPolicy.RUNTIME) -public @interface ApiHeaderCheck {} +public @interface ApiHeaderCheck { +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/annotation/S2DataPermission.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/annotation/S2DataPermission.java index 4979d2369..739175f29 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/annotation/S2DataPermission.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/annotation/S2DataPermission.java @@ -9,4 +9,5 @@ import java.lang.annotation.Target; @Target(ElementType.METHOD) @Retention(RetentionPolicy.RUNTIME) @Documented -public @interface S2DataPermission {} +public @interface S2DataPermission { +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/ApiHeaderCheckAspect.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/ApiHeaderCheckAspect.java index 6fc126fdd..98e3e02eb 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/ApiHeaderCheckAspect.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/ApiHeaderCheckAspect.java @@ -30,7 +30,8 @@ public class ApiHeaderCheckAspect { private static final String SIGNATURE = "signature"; - @Autowired private AppService appService; + @Autowired + private AppService appService; @Pointcut("@annotation(com.tencent.supersonic.headless.server.annotation.ApiHeaderCheck)") private void apiPermissionCheck() {} @@ -63,12 +64,8 @@ public class ApiHeaderCheckAspect { if (!AppStatus.ONLINE.equals(appDetailResp.getAppStatus())) { throw new InvalidArgumentException("该应用暂时为非在线状态"); } - Pair checkResult = - SignatureUtils.isValidSignature( - appId, - appDetailResp.getAppSecret(), - Long.parseLong(timestampStr), - signature); + Pair checkResult = SignatureUtils.isValidSignature(appId, + appDetailResp.getAppSecret(), Long.parseLong(timestampStr), signature); if (!checkResult.first) { throw new InvalidArgumentException(checkResult.second); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/DimValueAspect.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/DimValueAspect.java index 1edeb460c..feffc04dc 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/DimValueAspect.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/DimValueAspect.java @@ -44,10 +44,10 @@ public class DimValueAspect { @Value("${s2.dimension.value.map.enable:true}") private Boolean dimensionValueMapEnable; - @Autowired private DimensionService dimensionService; + @Autowired + private DimensionService dimensionService; - @Around( - "execution(* com.tencent.supersonic.headless.server.facade.service.SemanticLayerService.queryByReq(..))") + @Around("execution(* com.tencent.supersonic.headless.server.facade.service.SemanticLayerService.queryByReq(..))") public Object handleDimValue(ProceedingJoinPoint joinPoint) throws Throwable { if (!dimensionValueMapEnable) { log.debug("dimensionValueMapEnable is false, skip dimensionValueMap"); @@ -108,22 +108,14 @@ public class DimValueAspect { } // consider '=' filter if (expression.getOperator().equals(FilterOperatorEnum.EQUALS.getValue())) { - dimension.getDimValueMaps().stream() - .forEach( - dimValue -> { - if (!CollectionUtils.isEmpty(dimValue.getAlias()) - && dimValue.getAlias() - .contains( - expression - .getFieldValue() - .toString())) { - getFiledNameToValueMap( - filedNameToValueMap, - expression.getFieldValue().toString(), - dimValue.getTechName(), - expression.getFieldName()); - } - }); + dimension.getDimValueMaps().stream().forEach(dimValue -> { + if (!CollectionUtils.isEmpty(dimValue.getAlias()) && dimValue.getAlias() + .contains(expression.getFieldValue().toString())) { + getFiledNameToValueMap(filedNameToValueMap, + expression.getFieldValue().toString(), dimValue.getTechName(), + expression.getFieldName()); + } + }); } // consider 'in' filter,each element needs to judge. replaceInCondition(expression, dimension, filedNameToValueMap); @@ -141,9 +133,7 @@ public class DimValueAspect { return queryResultWithColumns; } - public void replaceInCondition( - FieldExpression expression, - DimensionResp dimension, + public void replaceInCondition(FieldExpression expression, DimensionResp dimension, Map> filedNameToValueMap) { if (expression.getOperator().equals(FilterOperatorEnum.IN.getValue())) { String fieldValue = JsonUtil.toString(expression.getFieldValue()); @@ -165,27 +155,20 @@ public class DimValueAspect { } } if (!revisedValues.equals(values)) { - getFiledNameToValueMap( - filedNameToValueMap, - JsonUtil.toString(values), - JsonUtil.toString(revisedValues), - expression.getFieldName()); + getFiledNameToValueMap(filedNameToValueMap, JsonUtil.toString(values), + JsonUtil.toString(revisedValues), expression.getFieldName()); } } } - public void getFiledNameToValueMap( - Map> filedNameToValueMap, - String oldValue, - String newValue, - String fieldName) { + public void getFiledNameToValueMap(Map> filedNameToValueMap, + String oldValue, String newValue, String fieldName) { Map map = new HashMap<>(); map.put(oldValue, newValue); filedNameToValueMap.put(fieldName, map); } - private void rewriteDimValue( - SemanticQueryResp semanticQueryResp, + private void rewriteDimValue(SemanticQueryResp semanticQueryResp, Map> dimAndTechNameAndBizNamePair) { if (!selectDimValueMap(semanticQueryResp.getColumns(), dimAndTechNameAndBizNamePair)) { return; @@ -209,8 +192,7 @@ public class DimValueAspect { } } - private boolean selectDimValueMap( - List columns, + private boolean selectDimValueMap(List columns, Map> dimAndTechNameAndBizNamePair) { if (CollectionUtils.isEmpty(dimAndTechNameAndBizNamePair) || CollectionUtils.isEmpty(dimAndTechNameAndBizNamePair)) { @@ -225,8 +207,8 @@ public class DimValueAspect { return false; } - private void rewriteFilter( - List dimensionFilters, Map> aliasAndTechNamePair) { + private void rewriteFilter(List dimensionFilters, + Map> aliasAndTechNamePair) { for (Filter filter : dimensionFilters) { if (Objects.isNull(filter)) { continue; @@ -283,18 +265,15 @@ public class DimValueAspect { continue; } if (StringUtils.isNotEmpty(dimValueMap.getBizName())) { - aliasAndBizNameToTechName.put( - dimValueMap.getBizName(), dimValueMap.getTechName()); + aliasAndBizNameToTechName.put(dimValueMap.getBizName(), + dimValueMap.getTechName()); } if (!CollectionUtils.isEmpty(dimValueMap.getAlias())) { - dimValueMap.getAlias().stream() - .forEach( - alias -> { - if (StringUtils.isNotEmpty(alias)) { - aliasAndBizNameToTechName.put( - alias, dimValueMap.getTechName()); - } - }); + dimValueMap.getAlias().stream().forEach(alias -> { + if (StringUtils.isNotEmpty(alias)) { + aliasAndBizNameToTechName.put(alias, dimValueMap.getTechName()); + } + }); } } @@ -339,8 +318,7 @@ public class DimValueAspect { } private boolean needSkipDimension(DimensionResp dimension) { - return Objects.isNull(dimension) - || StringUtils.isEmpty(dimension.getBizName()) + return Objects.isNull(dimension) || StringUtils.isEmpty(dimension.getBizName()) || CollectionUtils.isEmpty(dimension.getDimValueMaps()); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/S2DataPermissionAspect.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/S2DataPermissionAspect.java index f7879cb94..db75d99f4 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/S2DataPermissionAspect.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/aspect/S2DataPermissionAspect.java @@ -52,10 +52,14 @@ import java.util.stream.Collectors; @Slf4j public class S2DataPermissionAspect { - @Autowired private QueryStructUtils queryStructUtils; - @Autowired private ModelService modelService; - @Autowired private SchemaService schemaService; - @Autowired private AuthService authService; + @Autowired + private QueryStructUtils queryStructUtils; + @Autowired + private ModelService modelService; + @Autowired + private SchemaService schemaService; + @Autowired + private AuthService authService; @Pointcut("@annotation(com.tencent.supersonic.headless.server.annotation.S2DataPermission)") private void s2PermissionCheck() {} @@ -109,26 +113,19 @@ public class S2DataPermissionAspect { return result; } - private void checkColPermission( - SemanticQueryReq semanticQueryReq, - AuthorizedResourceResp authorizedResource, - Set modelIds, + private void checkColPermission(SemanticQueryReq semanticQueryReq, + AuthorizedResourceResp authorizedResource, Set modelIds, SemanticSchemaResp semanticSchemaResp) { // get high sensitive fields in query Set bizNamesInQueryReq = getBizNameInQueryReq(semanticQueryReq, semanticSchemaResp); Set sensitiveBizNamesByModel = getHighSensitiveBizNamesByModelId(semanticSchemaResp); - Set sensitiveBizNameInQuery = - bizNamesInQueryReq - .parallelStream() - .filter(sensitiveBizNamesByModel::contains) - .collect(Collectors.toSet()); + Set sensitiveBizNameInQuery = bizNamesInQueryReq.parallelStream() + .filter(sensitiveBizNamesByModel::contains).collect(Collectors.toSet()); // get high sensitive field cur user has been authed - Set sensitiveBizNameUserAuthed = - authorizedResource.getAuthResList().stream() - .map(AuthRes::getName) - .collect(Collectors.toSet()); + Set sensitiveBizNameUserAuthed = authorizedResource.getAuthResList().stream() + .map(AuthRes::getName).collect(Collectors.toSet()); sensitiveBizNameInQuery.removeAll(sensitiveBizNameUserAuthed); if (!CollectionUtils.isEmpty(sensitiveBizNameInQuery)) { Set sensitiveResNames = @@ -140,8 +137,8 @@ public class S2DataPermissionAspect { } } - private Set getModelIdInQuery( - SemanticQueryReq semanticQueryReq, SemanticSchemaResp semanticSchemaResp) { + private Set getModelIdInQuery(SemanticQueryReq semanticQueryReq, + SemanticSchemaResp semanticSchemaResp) { if (semanticQueryReq instanceof QuerySqlReq) { QuerySqlReq querySqlReq = (QuerySqlReq) semanticQueryReq; return queryStructUtils.getModelIdFromSql(querySqlReq, semanticSchemaResp); @@ -153,8 +150,8 @@ public class S2DataPermissionAspect { return Sets.newHashSet(); } - private void checkRowPermission( - SemanticQueryReq queryReq, AuthorizedResourceResp authorizedResource) { + private void checkRowPermission(SemanticQueryReq queryReq, + AuthorizedResourceResp authorizedResource) { if (queryReq instanceof QuerySqlReq) { doRowPermission((QuerySqlReq) queryReq, authorizedResource); } @@ -163,8 +160,8 @@ public class S2DataPermissionAspect { } } - private Set getBizNameInQueryReq( - SemanticQueryReq queryReq, SemanticSchemaResp semanticSchemaResp) { + private Set getBizNameInQueryReq(SemanticQueryReq queryReq, + SemanticSchemaResp semanticSchemaResp) { if (queryReq instanceof QuerySqlReq) { return queryStructUtils.getBizNameFromSql((QuerySqlReq) queryReq, semanticSchemaResp); } @@ -181,8 +178,8 @@ public class S2DataPermissionAspect { return schemaService.fetchSemanticSchema(filter); } - private void doRowPermission( - QuerySqlReq querySqlReq, AuthorizedResourceResp authorizedResource) { + private void doRowPermission(QuerySqlReq querySqlReq, + AuthorizedResourceResp authorizedResource) { log.debug("start doRowPermission logic"); StringJoiner joiner = new StringJoiner(" OR "); List dimensionFilters = new ArrayList<>(); @@ -196,14 +193,11 @@ public class S2DataPermissionAspect { return; } - dimensionFilters.stream() - .forEach( - filter -> { - if (StringUtils.isNotEmpty(filter) - && StringUtils.isNotEmpty(filter.trim())) { - joiner.add(" ( " + filter + " ) "); - } - }); + dimensionFilters.stream().forEach(filter -> { + if (StringUtils.isNotEmpty(filter) && StringUtils.isNotEmpty(filter.trim())) { + joiner.add(" ( " + filter + " ) "); + } + }); try { Expression expression = CCJSqlParserUtil.parseCondExpression(" ( " + joiner + " ) "); if (StringUtils.isNotEmpty(joiner.toString())) { @@ -217,8 +211,8 @@ public class S2DataPermissionAspect { } } - private void doRowPermission( - QueryStructReq queryStructReq, AuthorizedResourceResp authorizedResource) { + private void doRowPermission(QueryStructReq queryStructReq, + AuthorizedResourceResp authorizedResource) { log.debug("start doRowPermission logic"); StringJoiner joiner = new StringJoiner(" OR "); List dimensionFilters = new ArrayList<>(); @@ -232,21 +226,17 @@ public class S2DataPermissionAspect { return; } - dimensionFilters.stream() - .forEach( - filter -> { - if (StringUtils.isNotEmpty(filter) - && StringUtils.isNotEmpty(filter.trim())) { - joiner.add(" ( " + filter + " ) "); - } - }); + dimensionFilters.stream().forEach(filter -> { + if (StringUtils.isNotEmpty(filter) && StringUtils.isNotEmpty(filter.trim())) { + joiner.add(" ( " + filter + " ) "); + } + }); if (StringUtils.isNotEmpty(joiner.toString())) { log.info("before doRowPermission, queryStructReq:{}", queryStructReq); Filter filter = new Filter("", FilterOperatorEnum.SQL_PART, joiner.toString()); List filters = - Objects.isNull(queryStructReq.getOriginalFilter()) - ? new ArrayList<>() + Objects.isNull(queryStructReq.getOriginalFilter()) ? new ArrayList<>() : queryStructReq.getOriginalFilter(); filters.add(filter); queryStructReq.setDimensionFilters(filters); @@ -269,8 +259,7 @@ public class S2DataPermissionAspect { public void checkModelVisible(User user, Set modelIds) { List modelListVisible = modelService.getModelListWithAuth(user, null, AuthType.VISIBLE).stream() - .map(ModelResp::getId) - .collect(Collectors.toList()); + .map(ModelResp::getId).collect(Collectors.toList()); List modelIdCopied = new ArrayList<>(modelIds); modelIdCopied.removeAll(modelListVisible); if (!CollectionUtils.isEmpty(modelIdCopied)) { @@ -281,9 +270,8 @@ public class S2DataPermissionAspect { if (modelResp == null) { throw new InvalidArgumentException("查询的模型不存在"); } - String message = - String.format( - "您没有模型[%s]权限,请联系管理员%s开通", modelResp.getName(), modelResp.getAdmins()); + String message = String.format("您没有模型[%s]权限,请联系管理员%s开通", modelResp.getName(), + modelResp.getAdmins()); throw new InvalidPermissionException(message); } } @@ -292,20 +280,14 @@ public class S2DataPermissionAspect { Set highSensitiveCols = new HashSet<>(); if (!CollectionUtils.isEmpty(semanticSchemaResp.getDimensions())) { semanticSchemaResp.getDimensions().stream() - .filter( - dimSchemaResp -> - SensitiveLevelEnum.HIGH - .getCode() - .equals(dimSchemaResp.getSensitiveLevel())) + .filter(dimSchemaResp -> SensitiveLevelEnum.HIGH.getCode() + .equals(dimSchemaResp.getSensitiveLevel())) .forEach(dim -> highSensitiveCols.add(dim.getBizName())); } if (!CollectionUtils.isEmpty(semanticSchemaResp.getMetrics())) { semanticSchemaResp.getMetrics().stream() - .filter( - metricSchemaResp -> - SensitiveLevelEnum.HIGH - .getCode() - .equals(metricSchemaResp.getSensitiveLevel())) + .filter(metricSchemaResp -> SensitiveLevelEnum.HIGH.getCode() + .equals(metricSchemaResp.getSensitiveLevel())) .forEach(metric -> highSensitiveCols.add(metric.getBizName())); } return highSensitiveCols; @@ -315,11 +297,8 @@ public class S2DataPermissionAspect { QueryAuthResReq queryAuthResReq = new QueryAuthResReq(); queryAuthResReq.setModelIds(new ArrayList<>(modelIds)); AuthorizedResourceResp authorizedResource = fetchAuthRes(queryAuthResReq, user); - log.info( - "user:{}, domainId:{}, after queryAuthorizedResources:{}", - user.getName(), - modelIds, - authorizedResource); + log.info("user:{}, domainId:{}, after queryAuthorizedResources:{}", user.getName(), + modelIds, authorizedResource); return authorizedResource; } @@ -328,9 +307,7 @@ public class S2DataPermissionAspect { return authService.queryAuthorizedResources(queryAuthResReq, user); } - public void addHint( - Set modelIds, - SemanticQueryResp queryResultWithColumns, + public void addHint(Set modelIds, SemanticQueryResp queryResultWithColumns, AuthorizedResourceResp authorizedResource) { List filters = authorizedResource.getFilters(); if (CollectionUtils.isEmpty(filters)) { @@ -342,20 +319,15 @@ public class S2DataPermissionAspect { ModelResp modelResp = modelService.getModel(modelIds.iterator().next()); List exprList = new ArrayList<>(); List descList = new ArrayList<>(); - filters.stream() - .forEach( - filter -> { - if (StringUtils.isNotEmpty(filter.getDescription())) { - descList.add(filter.getDescription()); - } - exprList.add(filter.getExpressions().toString()); - }); + filters.stream().forEach(filter -> { + if (StringUtils.isNotEmpty(filter.getDescription())) { + descList.add(filter.getDescription()); + } + exprList.add(filter.getExpressions().toString()); + }); String promptInfo = "当前结果已经过行权限过滤,详细过滤条件如下:%s, 申请权限请联系管理员%s"; - String message = - String.format( - promptInfo, - CollectionUtils.isEmpty(descList) ? exprList : descList, - admins); + String message = String.format(promptInfo, + CollectionUtils.isEmpty(descList) ? exprList : descList, admins); queryResultWithColumns.setQueryAuthorization( new QueryAuthorization(modelResp.getName(), exprList, descList, message)); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/ChatQueryApiController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/ChatQueryApiController.java index 08533b9bb..7d562ce71 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/ChatQueryApiController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/ChatQueryApiController.java @@ -23,45 +23,36 @@ import org.springframework.web.bind.annotation.RestController; @Slf4j public class ChatQueryApiController { - @Autowired private ChatLayerService chatLayerService; + @Autowired + private ChatLayerService chatLayerService; - @Autowired private SemanticLayerService semanticLayerService; + @Autowired + private SemanticLayerService semanticLayerService; @PostMapping("/chat/search") - public Object search( - @RequestBody QueryNLReq queryNLReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public Object search(@RequestBody QueryNLReq queryNLReq, HttpServletRequest request, + HttpServletResponse response) throws Exception { queryNLReq.setUser(UserHolder.findUser(request, response)); return chatLayerService.retrieve(queryNLReq); } @PostMapping("/chat/map") - public Object map( - @RequestBody QueryNLReq queryNLReq, - HttpServletRequest request, + public Object map(@RequestBody QueryNLReq queryNLReq, HttpServletRequest request, HttpServletResponse response) { queryNLReq.setUser(UserHolder.findUser(request, response)); return chatLayerService.performMapping(queryNLReq); } @PostMapping("/chat/parse") - public Object parse( - @RequestBody QueryNLReq queryNLReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public Object parse(@RequestBody QueryNLReq queryNLReq, HttpServletRequest request, + HttpServletResponse response) throws Exception { queryNLReq.setUser(UserHolder.findUser(request, response)); return chatLayerService.performParsing(queryNLReq); } @PostMapping("/chat") - public Object queryByNL( - @RequestBody QueryNLReq queryNLReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public Object queryByNL(@RequestBody QueryNLReq queryNLReq, HttpServletRequest request, + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); ParseResp parseResp = chatLayerService.performParsing(queryNLReq); if (parseResp.getState().equals(ParseResp.ParseState.COMPLETED)) { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/DataSetQueryApiController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/DataSetQueryApiController.java index e4b4f325b..e22cd8b19 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/DataSetQueryApiController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/DataSetQueryApiController.java @@ -21,15 +21,14 @@ import org.springframework.web.bind.annotation.RestController; @Slf4j public class DataSetQueryApiController { - @Autowired private DataSetService dataSetService; - @Autowired private SemanticLayerService semanticLayerService; + @Autowired + private DataSetService dataSetService; + @Autowired + private SemanticLayerService semanticLayerService; @PostMapping("/dataSet") - public Object queryByDataSet( - @RequestBody QueryDataSetReq queryDataSetReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public Object queryByDataSet(@RequestBody QueryDataSetReq queryDataSetReq, + HttpServletRequest request, HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); SemanticQueryReq queryReq = dataSetService.convert(queryDataSetReq); return semanticLayerService.queryByReq(queryReq, user); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/MetaDiscoveryApiController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/MetaDiscoveryApiController.java index 1384a5973..a426ade41 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/MetaDiscoveryApiController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/MetaDiscoveryApiController.java @@ -19,14 +19,12 @@ import org.springframework.web.bind.annotation.RestController; @Slf4j public class MetaDiscoveryApiController { - @Autowired private ChatLayerService chatLayerService; + @Autowired + private ChatLayerService chatLayerService; @PostMapping("map") - public Object map( - @RequestBody QueryMapReq queryMapReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public Object map(@RequestBody QueryMapReq queryMapReq, HttpServletRequest request, + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); queryMapReq.setUser(user); return chatLayerService.map(queryMapReq); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/MetricQueryApiController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/MetricQueryApiController.java index d95414572..8da3b719c 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/MetricQueryApiController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/MetricQueryApiController.java @@ -24,39 +24,33 @@ import org.springframework.web.bind.annotation.RestController; @Slf4j public class MetricQueryApiController { - @Autowired private SemanticLayerService semanticLayerService; + @Autowired + private SemanticLayerService semanticLayerService; - @Autowired private MetricService metricService; + @Autowired + private MetricService metricService; - @Autowired private DownloadService downloadService; + @Autowired + private DownloadService downloadService; @PostMapping("/metric") - public Object queryByMetric( - @RequestBody QueryMetricReq queryMetricReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public Object queryByMetric(@RequestBody QueryMetricReq queryMetricReq, + HttpServletRequest request, HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); QueryStructReq queryStructReq = metricService.convert(queryMetricReq); return semanticLayerService.queryByReq(queryStructReq.convert(true), user); } @PostMapping("/download/metric") - public void downloadMetric( - @RequestBody DownloadMetricReq downloadMetricReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public void downloadMetric(@RequestBody DownloadMetricReq downloadMetricReq, + HttpServletRequest request, HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); downloadService.downloadByStruct(downloadMetricReq, user, response); } @PostMapping("/downloadBatch/metric") - public void downloadBatch( - @RequestBody BatchDownloadReq batchDownloadReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public void downloadBatch(@RequestBody BatchDownloadReq batchDownloadReq, + HttpServletRequest request, HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); downloadService.batchDownload(batchDownloadReq, user, response); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/SqlQueryApiController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/SqlQueryApiController.java index 7dd387445..7d4056120 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/SqlQueryApiController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/SqlQueryApiController.java @@ -30,16 +30,15 @@ import java.util.stream.Collectors; @Slf4j public class SqlQueryApiController { - @Autowired private SemanticLayerService semanticLayerService; + @Autowired + private SemanticLayerService semanticLayerService; - @Autowired private ChatLayerService chatLayerService; + @Autowired + private ChatLayerService chatLayerService; @PostMapping("/sql") - public Object queryBySql( - @RequestBody QuerySqlReq querySqlReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public Object queryBySql(@RequestBody QuerySqlReq querySqlReq, HttpServletRequest request, + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); String sql = querySqlReq.getSql(); querySqlReq.setSql(StringUtil.replaceBackticks(sql)); @@ -48,63 +47,40 @@ public class SqlQueryApiController { } @PostMapping("/sqls") - public Object queryBySqls( - @RequestBody QuerySqlsReq querySqlsReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public Object queryBySqls(@RequestBody QuerySqlsReq querySqlsReq, HttpServletRequest request, + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); - List semanticQueryReqs = - querySqlsReq.getSqls().stream() - .map( - sql -> { - QuerySqlReq querySqlReq = new QuerySqlReq(); - BeanUtils.copyProperties(querySqlsReq, querySqlReq); - querySqlReq.setSql(StringUtil.replaceBackticks(sql)); - chatLayerService.correct(querySqlReq, user); - return querySqlReq; - }) - .collect(Collectors.toList()); + List semanticQueryReqs = querySqlsReq.getSqls().stream().map(sql -> { + QuerySqlReq querySqlReq = new QuerySqlReq(); + BeanUtils.copyProperties(querySqlsReq, querySqlReq); + querySqlReq.setSql(StringUtil.replaceBackticks(sql)); + chatLayerService.correct(querySqlReq, user); + return querySqlReq; + }).collect(Collectors.toList()); List> futures = - semanticQueryReqs.stream() - .map( - querySqlReq -> - CompletableFuture.supplyAsync( - () -> { - try { - return semanticLayerService.queryByReq( - querySqlReq, user); - } catch (Exception e) { - log.error( - "querySqlReq:{},queryByReq error:", - querySqlReq, - e); - return new SemanticQueryResp(); - } - })) - .collect(Collectors.toList()); + semanticQueryReqs.stream().map(querySqlReq -> CompletableFuture.supplyAsync(() -> { + try { + return semanticLayerService.queryByReq(querySqlReq, user); + } catch (Exception e) { + log.error("querySqlReq:{},queryByReq error:", querySqlReq, e); + return new SemanticQueryResp(); + } + })).collect(Collectors.toList()); return futures.stream().map(CompletableFuture::join).collect(Collectors.toList()); } @PostMapping("/sqlsWithException") - public Object queryBySqlsWithException( - @RequestBody QuerySqlsReq querySqlsReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public Object queryBySqlsWithException(@RequestBody QuerySqlsReq querySqlsReq, + HttpServletRequest request, HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); - List semanticQueryReqs = - querySqlsReq.getSqls().stream() - .map( - sql -> { - QuerySqlReq querySqlReq = new QuerySqlReq(); - BeanUtils.copyProperties(querySqlsReq, querySqlReq); - querySqlReq.setSql(StringUtil.replaceBackticks(sql)); - chatLayerService.correct(querySqlReq, user); - return querySqlReq; - }) - .collect(Collectors.toList()); + List semanticQueryReqs = querySqlsReq.getSqls().stream().map(sql -> { + QuerySqlReq querySqlReq = new QuerySqlReq(); + BeanUtils.copyProperties(querySqlsReq, querySqlReq); + querySqlReq.setSql(StringUtil.replaceBackticks(sql)); + chatLayerService.correct(querySqlReq, user); + return querySqlReq; + }).collect(Collectors.toList()); List semanticQueryRespList = new ArrayList<>(); try { for (SemanticQueryReq semanticQueryReq : semanticQueryReqs) { @@ -119,11 +95,8 @@ public class SqlQueryApiController { } @PostMapping("/validate") - public Object validate( - @RequestBody QuerySqlReq querySqlReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public Object validate(@RequestBody QuerySqlReq querySqlReq, HttpServletRequest request, + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); String sql = querySqlReq.getSql(); querySqlReq.setSql(StringUtil.replaceBackticks(sql)); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/TagQueryApiController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/TagQueryApiController.java index 67ce37d65..8973667a0 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/TagQueryApiController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/rest/TagQueryApiController.java @@ -19,14 +19,12 @@ import org.springframework.web.bind.annotation.RestController; @Slf4j public class TagQueryApiController { - @Autowired private SemanticLayerService semanticLayerService; + @Autowired + private SemanticLayerService semanticLayerService; @PostMapping("/tag") - public Object queryByTag( - @RequestBody QueryStructReq queryStructReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public Object queryByTag(@RequestBody QueryStructReq queryStructReq, HttpServletRequest request, + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); return semanticLayerService.queryByReq(queryStructReq.convert(), user); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java index a3682bdf8..06b395b35 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2ChatLayerService.java @@ -53,20 +53,22 @@ import java.util.stream.Collectors; @Service @Slf4j public class S2ChatLayerService implements ChatLayerService { - @Autowired private SchemaService schemaService; - @Autowired private DataSetService dataSetService; - @Autowired private RetrieveService retrieveService; - @Autowired private ChatWorkflowEngine chatWorkflowEngine; + @Autowired + private SchemaService schemaService; + @Autowired + private DataSetService dataSetService; + @Autowired + private RetrieveService retrieveService; + @Autowired + private ChatWorkflowEngine chatWorkflowEngine; @Override public MapResp performMapping(QueryNLReq queryNLReq) { MapResp mapResp = new MapResp(queryNLReq.getQueryText()); ChatQueryContext queryCtx = buildChatQueryContext(queryNLReq); - ComponentFactory.getSchemaMappers() - .forEach( - mapper -> { - mapper.map(queryCtx); - }); + ComponentFactory.getSchemaMappers().forEach(mapper -> { + mapper.map(queryCtx); + }); mapResp.setMapInfo(queryCtx.getMapInfo()); return mapResp; } @@ -96,17 +98,12 @@ public class S2ChatLayerService implements ChatLayerService { private ChatQueryContext buildChatQueryContext(QueryNLReq queryNLReq) { SemanticSchema semanticSchema = schemaService.getSemanticSchema(queryNLReq.getDataSetIds()); Map> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds(); - ChatQueryContext queryCtx = - ChatQueryContext.builder() - .queryFilters(queryNLReq.getQueryFilters()) - .semanticSchema(semanticSchema) - .candidateQueries(new ArrayList<>()) - .mapInfo(new SchemaMapInfo()) - .modelIdToDataSetIds(modelIdToDataSetIds) - .text2SQLType(queryNLReq.getText2SQLType()) - .mapModeEnum(queryNLReq.getMapModeEnum()) - .dataSetIds(queryNLReq.getDataSetIds()) - .build(); + ChatQueryContext queryCtx = ChatQueryContext.builder() + .queryFilters(queryNLReq.getQueryFilters()).semanticSchema(semanticSchema) + .candidateQueries(new ArrayList<>()).mapInfo(new SchemaMapInfo()) + .modelIdToDataSetIds(modelIdToDataSetIds).text2SQLType(queryNLReq.getText2SQLType()) + .mapModeEnum(queryNLReq.getMapModeEnum()).dataSetIds(queryNLReq.getDataSetIds()) + .build(); BeanUtils.copyProperties(queryNLReq, queryCtx); return queryCtx; } @@ -146,14 +143,12 @@ public class S2ChatLayerService implements ChatLayerService { SchemaElement dataSet = semanticSchema.getDataSet(dataSetId); semanticParseInfo.setDataSet(dataSet); semanticParseInfo.setQueryConfig(semanticSchema.getQueryConfig(dataSetId)); - ComponentFactory.getSemanticCorrectors() - .forEach( - corrector -> { - if (!(corrector instanceof GrammarCorrector - || (corrector instanceof SchemaCorrector))) { - corrector.correct(queryCtx, semanticParseInfo); - } - }); + ComponentFactory.getSemanticCorrectors().forEach(corrector -> { + if (!(corrector instanceof GrammarCorrector + || (corrector instanceof SchemaCorrector))) { + corrector.correct(queryCtx, semanticParseInfo); + } + }); log.info("chatQueryServiceImpl correct:{}", sqlInfo.getCorrectedS2SQL()); return semanticParseInfo; } @@ -174,8 +169,8 @@ public class S2ChatLayerService implements ChatLayerService { return mapInfoResp; } - private Map getDataSetInfo( - SchemaMapInfo mapInfo, Map dataSetMap, Integer topN) { + private Map getDataSetInfo(SchemaMapInfo mapInfo, + Map dataSetMap, Integer topN) { Map map = new HashMap<>(); Map> mapFields = getMapFields(mapInfo, dataSetMap); Map> topFields = getTopFields(topN, mapInfo, dataSetMap); @@ -197,18 +192,15 @@ public class S2ChatLayerService implements ChatLayerService { return map; } - private Map> getMapFields( - SchemaMapInfo mapInfo, Map dataSetMap) { + private Map> getMapFields(SchemaMapInfo mapInfo, + Map dataSetMap) { Map> result = new HashMap<>(); - for (Map.Entry> entry : - mapInfo.getDataSetElementMatches().entrySet()) { - List values = - entry.getValue().stream() - .filter( - schemaElementMatch -> - !SchemaElementType.TERM.equals( - schemaElementMatch.getElement().getType())) - .collect(Collectors.toList()); + for (Map.Entry> entry : mapInfo.getDataSetElementMatches() + .entrySet()) { + List values = entry.getValue().stream() + .filter(schemaElementMatch -> !SchemaElementType.TERM + .equals(schemaElementMatch.getElement().getType())) + .collect(Collectors.toList()); if (CollectionUtils.isNotEmpty(values) && dataSetMap.containsKey(entry.getKey())) { result.put(entry.getKey(), values); } @@ -216,15 +208,15 @@ public class S2ChatLayerService implements ChatLayerService { return result; } - private Map> getTopFields( - Integer topN, SchemaMapInfo mapInfo, Map dataSetMap) { + private Map> getTopFields(Integer topN, SchemaMapInfo mapInfo, + Map dataSetMap) { Map> result = new HashMap<>(); if (0 == topN) { return result; } SemanticSchema semanticSchema = schemaService.getSemanticSchema(); - for (Map.Entry> entry : - mapInfo.getDataSetElementMatches().entrySet()) { + for (Map.Entry> entry : mapInfo.getDataSetElementMatches() + .entrySet()) { Long dataSetId = entry.getKey(); List values = entry.getValue(); DataSetResp dataSetResp = dataSetMap.get(dataSetId); @@ -233,23 +225,17 @@ public class S2ChatLayerService implements ChatLayerService { } String dataSetName = dataSetResp.getName(); // topN dimensions - Set dimensions = - semanticSchema.getDimensions(dataSetId).stream() - .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) - .limit(topN - 1) - .map(mergeFunction()) - .collect(Collectors.toSet()); + Set dimensions = semanticSchema.getDimensions(dataSetId).stream() + .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) + .limit(topN - 1).map(mergeFunction()).collect(Collectors.toSet()); SchemaElementMatch timeDimensionMatch = getTimeDimension(dataSetId, dataSetName); dimensions.add(timeDimensionMatch); // topN metrics - Set metrics = - semanticSchema.getMetrics(dataSetId).stream() - .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) - .limit(topN) - .map(mergeFunction()) - .collect(Collectors.toSet()); + Set metrics = semanticSchema.getMetrics(dataSetId).stream() + .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()).limit(topN) + .map(mergeFunction()).collect(Collectors.toSet()); dimensions.addAll(metrics); result.put(dataSetId, new ArrayList<>(dimensions)); @@ -257,8 +243,8 @@ public class S2ChatLayerService implements ChatLayerService { return result; } - private Map> getTerms( - SchemaMapInfo mapInfo, Map dataSetNameMap) { + private Map> getTerms(SchemaMapInfo mapInfo, + Map dataSetNameMap) { Map> termMap = new HashMap<>(); Map> dataSetElementMatches = mapInfo.getDataSetElementMatches(); @@ -267,13 +253,10 @@ public class S2ChatLayerService implements ChatLayerService { if (dataSetResp == null) { continue; } - List terms = - entry.getValue().stream() - .filter( - schemaElementMatch -> - SchemaElementType.TERM.equals( - schemaElementMatch.getElement().getType())) - .collect(Collectors.toList()); + List terms = entry.getValue().stream() + .filter(schemaElementMatch -> SchemaElementType.TERM + .equals(schemaElementMatch.getElement().getType())) + .collect(Collectors.toList()); termMap.put(dataSetResp.getName(), terms); } return termMap; @@ -287,34 +270,21 @@ public class S2ChatLayerService implements ChatLayerService { * @return */ private SchemaElementMatch getTimeDimension(Long dataSetId, String dataSetName) { - SchemaElement element = - SchemaElement.builder() - .dataSetId(dataSetId) - .dataSetName(dataSetName) - .type(SchemaElementType.DIMENSION) - .bizName(TimeDimensionEnum.DAY.getName()) - .build(); + SchemaElement element = SchemaElement.builder().dataSetId(dataSetId) + .dataSetName(dataSetName).type(SchemaElementType.DIMENSION) + .bizName(TimeDimensionEnum.DAY.getName()).build(); - SchemaElementMatch timeDimensionMatch = - SchemaElementMatch.builder() - .element(element) - .detectWord(TimeDimensionEnum.DAY.getChName()) - .word(TimeDimensionEnum.DAY.getChName()) - .similarity(1L) - .frequency(BaseWordBuilder.DEFAULT_FREQUENCY) - .build(); + SchemaElementMatch timeDimensionMatch = SchemaElementMatch.builder().element(element) + .detectWord(TimeDimensionEnum.DAY.getChName()) + .word(TimeDimensionEnum.DAY.getChName()).similarity(1L) + .frequency(BaseWordBuilder.DEFAULT_FREQUENCY).build(); return timeDimensionMatch; } private Function mergeFunction() { - return schemaElement -> - SchemaElementMatch.builder() - .element(schemaElement) - .frequency(BaseWordBuilder.DEFAULT_FREQUENCY) - .word(schemaElement.getName()) - .similarity(1) - .detectWord(schemaElement.getName()) - .build(); + return schemaElement -> SchemaElementMatch.builder().element(schemaElement) + .frequency(BaseWordBuilder.DEFAULT_FREQUENCY).word(schemaElement.getName()) + .similarity(1).detectWord(schemaElement.getName()).build(); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java index 56af63059..358047542 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/facade/service/impl/S2SemanticLayerService.java @@ -85,17 +85,11 @@ public class S2SemanticLayerService implements SemanticLayerService { private final QueryCache queryCache = ComponentFactory.getQueryCache(); private final List queryExecutors = ComponentFactory.getQueryExecutors(); - public S2SemanticLayerService( - StatUtils statUtils, - QueryUtils queryUtils, - QueryReqConverter queryReqConverter, - SemanticSchemaManager semanticSchemaManager, - DataSetService dataSetService, - SchemaService schemaService, - SemanticTranslator semanticTranslator, - MetricDrillDownChecker metricDrillDownChecker, - KnowledgeBaseService knowledgeBaseService, - MetricService metricService, + public S2SemanticLayerService(StatUtils statUtils, QueryUtils queryUtils, + QueryReqConverter queryReqConverter, SemanticSchemaManager semanticSchemaManager, + DataSetService dataSetService, SchemaService schemaService, + SemanticTranslator semanticTranslator, MetricDrillDownChecker metricDrillDownChecker, + KnowledgeBaseService knowledgeBaseService, MetricService metricService, DimensionService dimensionService) { this.statUtils = statUtils; this.queryUtils = queryUtils; @@ -119,11 +113,8 @@ public class S2SemanticLayerService implements SemanticLayerService { public SemanticTranslateResp translate(SemanticQueryReq queryReq, User user) throws Exception { QueryStatement queryStatement = buildQueryStatement(queryReq, user); semanticTranslator.translate(queryStatement); - return SemanticTranslateResp.builder() - .querySQL(queryStatement.getSql()) - .isOk(queryStatement.isOk()) - .errMsg(queryStatement.getErrMsg()) - .build(); + return SemanticTranslateResp.builder().querySQL(queryStatement.getSql()) + .isOk(queryStatement.isOk()).errMsg(queryStatement.getErrMsg()).build(); } @Override @@ -161,8 +152,8 @@ public class S2SemanticLayerService implements SemanticLayerService { for (QueryExecutor queryExecutor : queryExecutors) { if (queryExecutor.accept(queryStatement)) { queryResp = queryExecutor.execute(queryStatement); - queryUtils.populateQueryColumns( - queryResp, queryStatement.getSemanticSchemaResp()); + queryUtils.populateQueryColumns(queryResp, + queryStatement.getSemanticSchemaResp()); } } @@ -208,8 +199,8 @@ public class S2SemanticLayerService implements SemanticLayerService { return semanticQueryResp; } - private List getDimensionValuesFromDict( - DimensionValueReq dimensionValueReq, Set dataSetIds) { + private List getDimensionValuesFromDict(DimensionValueReq dimensionValueReq, + Set dataSetIds) { if (StringUtils.isBlank(dimensionValueReq.getValue())) { return SearchService.getDimensionValue(dimensionValueReq); } @@ -217,28 +208,19 @@ public class S2SemanticLayerService implements SemanticLayerService { Map> modelIdToDataSetIds = new HashMap<>(); modelIdToDataSetIds.put(dimensionValueReq.getModelId(), new ArrayList<>(dataSetIds)); - List hanlpMapResultList = - knowledgeBaseService.prefixSearch( - dimensionValueReq.getValue(), 2000, modelIdToDataSetIds, dataSetIds); + List hanlpMapResultList = knowledgeBaseService + .prefixSearch(dimensionValueReq.getValue(), 2000, modelIdToDataSetIds, dataSetIds); HanlpHelper.transLetterOriginal(hanlpMapResultList); return hanlpMapResultList.stream() - .filter( - o -> - o.getNatures().stream() - .map(NatureHelper::getElementID) - .anyMatch( - elementID -> - dimensionValueReq - .getElementID() - .equals(elementID))) - .map(MapResult::getName) - .collect(Collectors.toList()); + .filter(o -> o.getNatures().stream().map(NatureHelper::getElementID) + .anyMatch(elementID -> dimensionValueReq.getElementID().equals(elementID))) + .map(MapResult::getName).collect(Collectors.toList()); } - private SemanticQueryResp getDimensionValuesFromDb( - DimensionValueReq dimensionValueReq, User user) { + private SemanticQueryResp getDimensionValuesFromDb(DimensionValueReq dimensionValueReq, + User user) { QuerySqlReq querySqlReq = buildQuerySqlReq(dimensionValueReq); return queryByReq(querySqlReq, user); } @@ -255,16 +237,13 @@ public class S2SemanticLayerService implements SemanticLayerService { return columns; } - private List> createResultList( - DimensionValueReq dimensionValueReq, List dimensionValues) { - return dimensionValues.stream() - .map( - value -> { - Map map = new HashMap<>(); - map.put(dimensionValueReq.getBizName(), value); - return map; - }) - .collect(Collectors.toList()); + private List> createResultList(DimensionValueReq dimensionValueReq, + List dimensionValues) { + return dimensionValues.stream().map(value -> { + Map map = new HashMap<>(); + map.put(dimensionValueReq.getBizName(), value); + return map; + }).collect(Collectors.toList()); } private DimensionResp getDimension(DimensionValueReq dimensionValueReq) { @@ -278,8 +257,8 @@ public class S2SemanticLayerService implements SemanticLayerService { return dimensionResp; } - public EntityInfo getEntityInfo( - SemanticParseInfo parseInfo, DataSetSchema dataSetSchema, User user) { + public EntityInfo getEntityInfo(SemanticParseInfo parseInfo, DataSetSchema dataSetSchema, + User user) { if (parseInfo != null && parseInfo.getDataSetId() != null && parseInfo.getDataSetId() > 0) { EntityInfo entityInfo = getEntityBasicInfo(dataSetSchema); if (parseInfo.getDimensionFilters().size() <= 0 @@ -292,8 +271,7 @@ public class S2SemanticLayerService implements SemanticLayerService { if (StringUtils.isNotBlank(primaryKey)) { String entityId = ""; for (QueryFilter chatFilter : parseInfo.getDimensionFilters()) { - if (chatFilter != null - && chatFilter.getBizName() != null + if (chatFilter != null && chatFilter.getBizName() != null && chatFilter.getBizName().equals(primaryKey)) { if (chatFilter.getOperator().equals(FilterOperatorEnum.EQUALS)) { entityId = chatFilter.getValue().toString(); @@ -377,8 +355,7 @@ public class S2SemanticLayerService implements SemanticLayerService { queryStatement = buildMultiStructQueryStatement((QueryMultiStructReq) semanticQueryReq, user); } - if (Objects.nonNull(queryStatement) - && Objects.nonNull(semanticQueryReq.getSqlInfo()) + if (Objects.nonNull(queryStatement) && Objects.nonNull(semanticQueryReq.getSqlInfo()) && StringUtils.isNotBlank(semanticQueryReq.getSqlInfo().getQuerySQL())) { queryStatement.setSql(semanticQueryReq.getSqlInfo().getQuerySQL()); queryStatement.setDataSetId(semanticQueryReq.getDataSetId()); @@ -402,8 +379,8 @@ public class S2SemanticLayerService implements SemanticLayerService { return queryStatement; } - private QueryStatement buildMultiStructQueryStatement( - QueryMultiStructReq queryMultiStructReq, User user) throws Exception { + private QueryStatement buildMultiStructQueryStatement(QueryMultiStructReq queryMultiStructReq, + User user) throws Exception { List sqlParsers = new ArrayList<>(); for (QueryStructReq queryStructReq : queryMultiStructReq.getQueryStructReqs()) { QueryStatement queryStatement = buildQueryStatement(queryStructReq, user); @@ -429,32 +406,20 @@ public class S2SemanticLayerService implements SemanticLayerService { QuerySqlReq querySqlReq = new QuerySqlReq(); List modelResps = schemaService.getModelList(Lists.newArrayList(queryDimValueReq.getModelId())); - DimensionResp dimensionResp = - schemaService.getDimension( - queryDimValueReq.getBizName(), queryDimValueReq.getModelId()); + DimensionResp dimensionResp = schemaService.getDimension(queryDimValueReq.getBizName(), + queryDimValueReq.getModelId()); ModelResp modelResp = modelResps.get(0); - String sql = - String.format( - "select distinct %s from %s where 1=1", - dimensionResp.getName(), modelResp.getName()); + String sql = String.format("select distinct %s from %s where 1=1", dimensionResp.getName(), + modelResp.getName()); List timeDims = modelResp.getTimeDimension(); if (CollectionUtils.isNotEmpty(timeDims)) { - sql = - String.format( - "%s and %s >= '%s' and %s <= '%s'", - sql, - TimeDimensionEnum.DAY.getName(), - queryDimValueReq.getDateInfo().getStartDate(), - TimeDimensionEnum.DAY.getName(), - queryDimValueReq.getDateInfo().getEndDate()); + sql = String.format("%s and %s >= '%s' and %s <= '%s'", sql, + TimeDimensionEnum.DAY.getName(), queryDimValueReq.getDateInfo().getStartDate(), + TimeDimensionEnum.DAY.getName(), queryDimValueReq.getDateInfo().getEndDate()); } if (StringUtils.isNotBlank(queryDimValueReq.getValue())) { - sql += - " AND " - + queryDimValueReq.getBizName() - + " LIKE '%" - + queryDimValueReq.getValue() - + "%'"; + sql += " AND " + queryDimValueReq.getBizName() + " LIKE '%" + + queryDimValueReq.getValue() + "%'"; } querySqlReq.setModelIds(Sets.newHashSet(queryDimValueReq.getModelId())); querySqlReq.setSql(sql); @@ -488,48 +453,32 @@ public class S2SemanticLayerService implements SemanticLayerService { || detailTypeDefaultConfig.getDefaultDisplayInfo() == null) { return entityInfo; } - List dimensions = - detailTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds().stream() - .map( - id -> { - SchemaElement element = - dataSetSchema.getElement( - SchemaElementType.DIMENSION, id); - if (element == null) { - return null; - } - return new DataInfo( - element.getId().intValue(), - element.getName(), - element.getBizName(), - null); - }) - .filter(Objects::nonNull) - .collect(Collectors.toList()); - List metrics = - detailTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds().stream() - .map( - id -> { - SchemaElement element = - dataSetSchema.getElement(SchemaElementType.METRIC, id); - if (element == null) { - return null; - } - return new DataInfo( - element.getId().intValue(), - element.getName(), - element.getBizName(), - null); - }) - .filter(Objects::nonNull) - .collect(Collectors.toList()); + List dimensions = detailTypeDefaultConfig.getDefaultDisplayInfo() + .getDimensionIds().stream().map(id -> { + SchemaElement element = + dataSetSchema.getElement(SchemaElementType.DIMENSION, id); + if (element == null) { + return null; + } + return new DataInfo(element.getId().intValue(), element.getName(), + element.getBizName(), null); + }).filter(Objects::nonNull).collect(Collectors.toList()); + List metrics = detailTypeDefaultConfig.getDefaultDisplayInfo().getDimensionIds() + .stream().map(id -> { + SchemaElement element = dataSetSchema.getElement(SchemaElementType.METRIC, id); + if (element == null) { + return null; + } + return new DataInfo(element.getId().intValue(), element.getName(), + element.getBizName(), null); + }).filter(Objects::nonNull).collect(Collectors.toList()); entityInfo.setDimensions(dimensions); entityInfo.setMetrics(metrics); return entityInfo; } - private void fillEntityInfoValue( - EntityInfo entityInfo, DataSetSchema dataSetSchema, User user) { + private void fillEntityInfoValue(EntityInfo entityInfo, DataSetSchema dataSetSchema, + User user) { SemanticQueryResp queryResultWithColumns = getQueryResultWithSchemaResp(entityInfo, dataSetSchema, user); if (queryResultWithColumns != null) { @@ -540,19 +489,17 @@ public class S2SemanticLayerService implements SemanticLayerService { if (entry.getValue() == null || entryKey == null) { continue; } - entityInfo.getDimensions().stream() - .filter(i -> entryKey.equals(i.getBizName())) + entityInfo.getDimensions().stream().filter(i -> entryKey.equals(i.getBizName())) .forEach(i -> i.setValue(entry.getValue().toString())); - entityInfo.getMetrics().stream() - .filter(i -> entryKey.equals(i.getBizName())) + entityInfo.getMetrics().stream().filter(i -> entryKey.equals(i.getBizName())) .forEach(i -> i.setValue(entry.getValue().toString())); } } } } - private SemanticQueryResp getQueryResultWithSchemaResp( - EntityInfo entityInfo, DataSetSchema dataSetSchema, User user) { + private SemanticQueryResp getQueryResultWithSchemaResp(EntityInfo entityInfo, + DataSetSchema dataSetSchema, User user) { SemanticParseInfo semanticParseInfo = new SemanticParseInfo(); semanticParseInfo.setDataSet(dataSetSchema.getDataSet()); semanticParseInfo.setQueryType(QueryType.DETAIL); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java index 7269e5c1d..0133f8df9 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/MetaEmbeddingListener.java @@ -21,9 +21,11 @@ import java.util.List; @Slf4j public class MetaEmbeddingListener implements ApplicationListener { - @Autowired private EmbeddingConfig embeddingConfig; + @Autowired + private EmbeddingConfig embeddingConfig; - @Autowired private EmbeddingService embeddingService; + @Autowired + private EmbeddingService embeddingService; @Value("${s2.embedding.operation.sleep.time:3000}") private Integer embeddingOperationSleepTime; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/SchemaDictUpdateListener.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/SchemaDictUpdateListener.java index f2a7c3c04..6561821e0 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/SchemaDictUpdateListener.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/listener/SchemaDictUpdateListener.java @@ -22,28 +22,24 @@ public class SchemaDictUpdateListener implements ApplicationListener if (CollectionUtils.isEmpty(dataEvent.getDataItems())) { return; } - dataEvent - .getDataItems() - .forEach( - dataItem -> { - DictWord dictWord = new DictWord(); - dictWord.setWord(dataItem.getName()); - String sign = DictWordType.NATURE_SPILT; - String suffixNature = DictWordType.getSuffixNature(dataItem.getType()); - String nature = - sign + dataItem.getModelId() + dataItem.getId() + suffixNature; - String natureWithFrequency = nature + " " + Constants.DEFAULT_FREQUENCY; - dictWord.setNature(nature); - dictWord.setNatureWithFrequency(natureWithFrequency); - if (EventType.ADD.equals(dataEvent.getEventType())) { - HanlpHelper.addToCustomDictionary(dictWord); - } else if (EventType.DELETE.equals(dataEvent.getEventType())) { - HanlpHelper.removeFromCustomDictionary(dictWord); - } else if (EventType.UPDATE.equals(dataEvent.getEventType())) { - HanlpHelper.removeFromCustomDictionary(dictWord); - dictWord.setWord(dataItem.getNewName()); - HanlpHelper.addToCustomDictionary(dictWord); - } - }); + dataEvent.getDataItems().forEach(dataItem -> { + DictWord dictWord = new DictWord(); + dictWord.setWord(dataItem.getName()); + String sign = DictWordType.NATURE_SPILT; + String suffixNature = DictWordType.getSuffixNature(dataItem.getType()); + String nature = sign + dataItem.getModelId() + dataItem.getId() + suffixNature; + String natureWithFrequency = nature + " " + Constants.DEFAULT_FREQUENCY; + dictWord.setNature(nature); + dictWord.setNatureWithFrequency(natureWithFrequency); + if (EventType.ADD.equals(dataEvent.getEventType())) { + HanlpHelper.addToCustomDictionary(dictWord); + } else if (EventType.DELETE.equals(dataEvent.getEventType())) { + HanlpHelper.removeFromCustomDictionary(dictWord); + } else if (EventType.UPDATE.equals(dataEvent.getEventType())) { + HanlpHelper.removeFromCustomDictionary(dictWord); + dictWord.setWord(dataItem.getNewName()); + HanlpHelper.addToCustomDictionary(dictWord); + } + }); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/manager/DimensionYamlManager.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/manager/DimensionYamlManager.java index 630ab6a96..ae76ef252 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/manager/DimensionYamlManager.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/manager/DimensionYamlManager.java @@ -22,14 +22,9 @@ public class DimensionYamlManager { return new ArrayList<>(); } return dimensions.stream() - .filter( - dimension -> - !dimension - .getType() - .name() - .equalsIgnoreCase(IdentifyType.primary.name())) - .map(DimensionYamlManager::convert2DimensionYamlTpl) - .collect(Collectors.toList()); + .filter(dimension -> !dimension.getType().name() + .equalsIgnoreCase(IdentifyType.primary.name())) + .map(DimensionYamlManager::convert2DimensionYamlTpl).collect(Collectors.toList()); } public static DimensionYamlTpl convert2DimensionYamlTpl(DimensionResp dimension) { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/manager/ModelYamlManager.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/manager/ModelYamlManager.java index 6ab181f5e..3a96b718f 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/manager/ModelYamlManager.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/manager/ModelYamlManager.java @@ -28,8 +28,8 @@ import java.util.stream.Collectors; @Slf4j public class ModelYamlManager { - public static synchronized DataModelYamlTpl convert2YamlObj( - ModelResp modelResp, DatabaseResp databaseResp) { + public static synchronized DataModelYamlTpl convert2YamlObj(ModelResp modelResp, + DatabaseResp databaseResp) { ModelDetail modelDetail = modelResp.getModelDetail(); DbAdaptor engineAdaptor = DbAdaptorFactory.getEngineAdaptor(databaseResp.getType()); SysTimeDimensionBuilder.addSysTimeDimension(modelDetail.getDimensions(), engineAdaptor); @@ -37,18 +37,12 @@ public class ModelYamlManager { DataModelYamlTpl dataModelYamlTpl = new DataModelYamlTpl(); dataModelYamlTpl.setType(databaseResp.getType()); BeanUtils.copyProperties(modelDetail, dataModelYamlTpl); - dataModelYamlTpl.setIdentifiers( - modelDetail.getIdentifiers().stream() - .map(ModelYamlManager::convert) - .collect(Collectors.toList())); - dataModelYamlTpl.setDimensions( - modelDetail.getDimensions().stream() - .map(ModelYamlManager::convert) - .collect(Collectors.toList())); - dataModelYamlTpl.setMeasures( - modelDetail.getMeasures().stream() - .map(ModelYamlManager::convert) - .collect(Collectors.toList())); + dataModelYamlTpl.setIdentifiers(modelDetail.getIdentifiers().stream() + .map(ModelYamlManager::convert).collect(Collectors.toList())); + dataModelYamlTpl.setDimensions(modelDetail.getDimensions().stream() + .map(ModelYamlManager::convert).collect(Collectors.toList())); + dataModelYamlTpl.setMeasures(modelDetail.getMeasures().stream() + .map(ModelYamlManager::convert).collect(Collectors.toList())); dataModelYamlTpl.setName(modelResp.getBizName()); dataModelYamlTpl.setSourceId(modelResp.getDatabaseId()); if (modelDetail.getQueryType().equalsIgnoreCase(ModelDefineType.SQL_QUERY.getName())) { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/manager/SemanticSchemaManager.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/manager/SemanticSchemaManager.java index d909da451..628a3b79d 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/manager/SemanticSchemaManager.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/manager/SemanticSchemaManager.java @@ -64,12 +64,8 @@ public class SemanticSchemaManager { List dataModelYamlTpls = new ArrayList<>(); List metricYamlTpls = new ArrayList<>(); Map modelIdName = new HashMap<>(); - schemaService.getSchemaYamlTpl( - semanticSchemaResp, - dimensionYamlTpls, - dataModelYamlTpls, - metricYamlTpls, - modelIdName); + schemaService.getSchemaYamlTpl(semanticSchemaResp, dimensionYamlTpls, dataModelYamlTpls, + metricYamlTpls, modelIdName); DatabaseResp databaseResp = semanticSchemaResp.getDatabaseResp(); semanticModel.setDatabase(DatabaseConverter.convert(databaseResp)); if (!CollectionUtils.isEmpty(semanticSchemaResp.getModelRelas())) { @@ -78,11 +74,8 @@ public class SemanticSchemaManager { } if (!dataModelYamlTpls.isEmpty()) { Map dataSourceMap = - dataModelYamlTpls.stream() - .map(SemanticSchemaManager::getDatasource) - .collect( - Collectors.toMap( - DataSource::getName, item -> item, (k1, k2) -> k1)); + dataModelYamlTpls.stream().map(SemanticSchemaManager::getDatasource).collect( + Collectors.toMap(DataSource::getName, item -> item, (k1, k2) -> k1)); semanticModel.setDatasourceMap(dataSourceMap); } if (!dimensionYamlTpls.isEmpty()) { @@ -114,8 +107,8 @@ public class SemanticSchemaManager { } if (Objects.nonNull(semanticModel.getDatasourceMap()) && !semanticModel.getDatasourceMap().isEmpty()) { - for (Map.Entry entry : - semanticModel.getDatasourceMap().entrySet()) { + for (Map.Entry entry : semanticModel.getDatasourceMap() + .entrySet()) { List modelDimensions = new ArrayList<>(); if (!semanticModel.getDimensionMap().containsKey(entry.getKey())) { semanticModel.getDimensionMap().put(entry.getKey(), modelDimensions); @@ -133,18 +126,16 @@ public class SemanticSchemaManager { return semanticModel; } - private void addTagModel( - TagResp tagResp, List modelDimensions, List modelMetrics) - throws Exception { + private void addTagModel(TagResp tagResp, List modelDimensions, + List modelMetrics) throws Exception { TagDefineType tagDefineType = TagDefineType.valueOf(tagResp.getTagDefineType()); switch (tagDefineType) { case FIELD: case DIMENSION: if (TagDefineType.DIMENSION.equals(tagResp.getTagDefineType())) { - Optional modelDimension = - modelDimensions.stream() - // .filter(d -> d.getBizName().equals(tagResp.getExpr())) - .findFirst(); + Optional modelDimension = modelDimensions.stream() + // .filter(d -> d.getBizName().equals(tagResp.getExpr())) + .findFirst(); if (modelDimension.isPresent()) { modelDimension.get().setName(tagResp.getBizName()); return; @@ -152,7 +143,7 @@ public class SemanticSchemaManager { } Dimension dimension = Dimension.builder().build(); dimension.setType(""); - // dimension.setExpr(tagResp.getExpr()); + // dimension.setExpr(tagResp.getExpr()); dimension.setName(tagResp.getBizName()); dimension.setOwners(""); dimension.setBizName(tagResp.getBizName()); @@ -165,10 +156,9 @@ public class SemanticSchemaManager { modelDimensions.add(dimension); return; case METRIC: - Optional modelMetric = - modelMetrics.stream() - // .filter(m -> m.getName().equalsIgnoreCase(tagResp.getExpr())) - .findFirst(); + Optional modelMetric = modelMetrics.stream() + // .filter(m -> m.getName().equalsIgnoreCase(tagResp.getExpr())) + .findFirst(); if (modelMetric.isPresent()) { modelMetric.get().setName(tagResp.getBizName()); } else { @@ -189,37 +179,22 @@ public class SemanticSchemaManager { } public static DataSource getDatasource(final DataModelYamlTpl d) { - DataSource datasource = - DataSource.builder() - .id(d.getId()) - .sourceId(d.getSourceId()) - .type(d.getType()) - .sqlQuery(d.getSqlQuery()) - .name(d.getName()) - .tableQuery(d.getTableQuery()) - .identifiers(getIdentify(d.getIdentifiers())) - .measures(getMeasureParams(d.getMeasures())) - .dimensions(getDimensions(d.getDimensions())) - .build(); + DataSource datasource = DataSource.builder().id(d.getId()).sourceId(d.getSourceId()) + .type(d.getType()).sqlQuery(d.getSqlQuery()).name(d.getName()) + .tableQuery(d.getTableQuery()).identifiers(getIdentify(d.getIdentifiers())) + .measures(getMeasureParams(d.getMeasures())) + .dimensions(getDimensions(d.getDimensions())).build(); datasource.setAggTime(getDataSourceAggTime(datasource.getDimensions())); if (Objects.nonNull(d.getModelSourceTypeEnum())) { datasource.setTimePartType(TimePartType.of(d.getModelSourceTypeEnum().name())); } if (Objects.nonNull(d.getFields()) && !CollectionUtils.isEmpty(d.getFields())) { - Set measures = - datasource.getMeasures().stream() - .map(mm -> mm.getName()) - .collect(Collectors.toSet()); + Set measures = datasource.getMeasures().stream().map(mm -> mm.getName()) + .collect(Collectors.toSet()); for (Field f : d.getFields()) { if (!measures.contains(f.getFieldName())) { - datasource - .getMeasures() - .add( - Measure.builder() - .expr(f.getFieldName()) - .name(f.getFieldName()) - .agg("") - .build()); + datasource.getMeasures().add(Measure.builder().expr(f.getFieldName()) + .name(f.getFieldName()).agg("").build()); } } } @@ -227,10 +202,9 @@ public class SemanticSchemaManager { } private static String getDataSourceAggTime(List dimensions) { - Optional timeDimension = - dimensions.stream() - .filter(d -> Constants.DIMENSION_TYPE_TIME.equalsIgnoreCase(d.getType())) - .findFirst(); + Optional timeDimension = dimensions.stream() + .filter(d -> Constants.DIMENSION_TYPE_TIME.equalsIgnoreCase(d.getType())) + .findFirst(); if (timeDimension.isPresent() && Objects.nonNull(timeDimension.get().getDimensionTimeTypeParams())) { return timeDimension.get().getDimensionTimeTypeParams().getTimeGranularity(); @@ -340,8 +314,8 @@ public class SemanticSchemaManager { DimensionTimeTypeParamsTpl dimensionTimeTypeParamsTpl) { DimensionTimeTypeParams dimensionTimeTypeParams = new DimensionTimeTypeParams(); if (dimensionTimeTypeParamsTpl != null) { - dimensionTimeTypeParams.setTimeGranularity( - dimensionTimeTypeParamsTpl.getTimeGranularity()); + dimensionTimeTypeParams + .setTimeGranularity(dimensionTimeTypeParamsTpl.getTimeGranularity()); dimensionTimeTypeParams.setIsPrimary(dimensionTimeTypeParamsTpl.getIsPrimary()); } return dimensionTimeTypeParams; @@ -358,38 +332,27 @@ public class SemanticSchemaManager { return identifies; } - private static List getJoinRelation( - List modelRelas, Map modelIdName) { + private static List getJoinRelation(List modelRelas, + Map modelIdName) { List joinRelations = new ArrayList<>(); - modelRelas.stream() - .forEach( - r -> { - if (modelIdName.containsKey(r.getFromModelId()) - && modelIdName.containsKey(r.getToModelId())) { - JoinRelation joinRelation = - JoinRelation.builder() - .left(modelIdName.get(r.getFromModelId())) - .right(modelIdName.get(r.getToModelId())) - .joinType(r.getJoinType()) - .build(); - List> conditions = new ArrayList<>(); - r.getJoinConditions().stream() - .forEach( - rr -> { - if (FilterOperatorEnum.isValueCompare( - rr.getOperator())) { - conditions.add( - Triple.of( - rr.getLeftField(), - rr.getOperator().getValue(), - rr.getRightField())); - } - }); - joinRelation.setId(r.getId()); - joinRelation.setJoinCondition(conditions); - joinRelations.add(joinRelation); - } - }); + modelRelas.stream().forEach(r -> { + if (modelIdName.containsKey(r.getFromModelId()) + && modelIdName.containsKey(r.getToModelId())) { + JoinRelation joinRelation = JoinRelation.builder() + .left(modelIdName.get(r.getFromModelId())) + .right(modelIdName.get(r.getToModelId())).joinType(r.getJoinType()).build(); + List> conditions = new ArrayList<>(); + r.getJoinConditions().stream().forEach(rr -> { + if (FilterOperatorEnum.isValueCompare(rr.getOperator())) { + conditions.add(Triple.of(rr.getLeftField(), rr.getOperator().getValue(), + rr.getRightField())); + } + }); + joinRelation.setId(r.getId()); + joinRelation.setJoinCondition(conditions); + joinRelations.add(joinRelation); + } + }); return joinRelations; } @@ -405,8 +368,7 @@ public class SemanticSchemaManager { String dataSourceName = datasourceYamlTpl.getName(); Optional> datasourceYamlTplMap = schema.getDatasource().entrySet().stream() - .filter(t -> t.getKey().equalsIgnoreCase(dataSourceName)) - .findFirst(); + .filter(t -> t.getKey().equalsIgnoreCase(dataSourceName)).findFirst(); if (datasourceYamlTplMap.isPresent()) { datasourceYamlTplMap.get().setValue(datasourceYamlTpl); } else { @@ -415,14 +377,12 @@ public class SemanticSchemaManager { } } - public static void update( - SemanticSchema schema, String datasourceBizName, List dimensionYamlTpls) - throws Exception { + public static void update(SemanticSchema schema, String datasourceBizName, + List dimensionYamlTpls) throws Exception { if (schema != null) { - Optional>> datasourceYamlTplMap = - schema.getDimension().entrySet().stream() - .filter(t -> t.getKey().equalsIgnoreCase(datasourceBizName)) - .findFirst(); + Optional>> datasourceYamlTplMap = schema + .getDimension().entrySet().stream() + .filter(t -> t.getKey().equalsIgnoreCase(datasourceBizName)).findFirst(); if (datasourceYamlTplMap.isPresent()) { updateDimension(dimensionYamlTpls, datasourceYamlTplMap.get().getValue()); } else { @@ -433,8 +393,8 @@ public class SemanticSchemaManager { } } - private static void updateDimension( - List dimensionYamlTpls, List dimensions) { + private static void updateDimension(List dimensionYamlTpls, + List dimensions) { if (CollectionUtils.isEmpty(dimensionYamlTpls)) { return; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/dataobject/DatabaseDOExample.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/dataobject/DatabaseDOExample.java index d3ce7cacb..8a387bd08 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/dataobject/DatabaseDOExample.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/dataobject/DatabaseDOExample.java @@ -139,8 +139,8 @@ public class DatabaseDOExample { criteria.add(new Criterion(condition, value)); } - protected void addCriterion( - String condition, Object value1, Object value2, String property) { + protected void addCriterion(String condition, Object value1, Object value2, + String property) { if (value1 == null || value2 == null) { throw new RuntimeException("Between values for " + property + " cannot be null"); } @@ -969,8 +969,8 @@ public class DatabaseDOExample { this(condition, value, null); } - protected Criterion( - String condition, Object value, Object secondValue, String typeHandler) { + protected Criterion(String condition, Object value, Object secondValue, + String typeHandler) { super(); this.condition = condition; this.value = value; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/dataobject/DomainDOExample.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/dataobject/DomainDOExample.java index bee796f64..b07ed7c51 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/dataobject/DomainDOExample.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/dataobject/DomainDOExample.java @@ -139,8 +139,8 @@ public class DomainDOExample { criteria.add(new Criterion(condition, value)); } - protected void addCriterion( - String condition, Object value1, Object value2, String property) { + protected void addCriterion(String condition, Object value1, Object value2, + String property) { if (value1 == null || value2 == null) { throw new RuntimeException("Between values for " + property + " cannot be null"); } @@ -1117,8 +1117,8 @@ public class DomainDOExample { this(condition, value, null); } - protected Criterion( - String condition, Object value, Object secondValue, String typeHandler) { + protected Criterion(String condition, Object value, Object secondValue, + String typeHandler) { super(); this.condition = condition; this.value = value; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/AppMapper.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/AppMapper.java index 67fb28b93..85fd1691d 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/AppMapper.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/AppMapper.java @@ -5,4 +5,5 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.AppDO; import org.apache.ibatis.annotations.Mapper; @Mapper -public interface AppMapper extends BaseMapper {} +public interface AppMapper extends BaseMapper { +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/CanvasDOMapper.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/CanvasDOMapper.java index f9589000d..e5dc70e34 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/CanvasDOMapper.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/CanvasDOMapper.java @@ -5,4 +5,5 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.CanvasDO; import org.apache.ibatis.annotations.Mapper; @Mapper -public interface CanvasDOMapper extends BaseMapper {} +public interface CanvasDOMapper extends BaseMapper { +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/ClassMapper.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/ClassMapper.java index 760f4a896..fd760d883 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/ClassMapper.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/ClassMapper.java @@ -5,4 +5,5 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.ClassDO; import org.apache.ibatis.annotations.Mapper; @Mapper -public interface ClassMapper extends BaseMapper {} +public interface ClassMapper extends BaseMapper { +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/CollectMapper.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/CollectMapper.java index 5753eb1e9..8516a6df0 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/CollectMapper.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/CollectMapper.java @@ -11,4 +11,5 @@ import org.apache.ibatis.annotations.Mapper; * @since 2023-11-09 03:49:33 */ @Mapper -public interface CollectMapper extends BaseMapper {} +public interface CollectMapper extends BaseMapper { +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DataSetDOMapper.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DataSetDOMapper.java index 220d1a1fe..ac69011ef 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DataSetDOMapper.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DataSetDOMapper.java @@ -5,4 +5,5 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.DataSetDO; import org.apache.ibatis.annotations.Mapper; @Mapper -public interface DataSetDOMapper extends BaseMapper {} +public interface DataSetDOMapper extends BaseMapper { +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DatabaseDOMapper.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DatabaseDOMapper.java index 21fb39401..aafa559d4 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DatabaseDOMapper.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DatabaseDOMapper.java @@ -5,4 +5,5 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.DatabaseDO; import org.apache.ibatis.annotations.Mapper; @Mapper -public interface DatabaseDOMapper extends BaseMapper {} +public interface DatabaseDOMapper extends BaseMapper { +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DictConfMapper.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DictConfMapper.java index c9d7cee7a..b3e9f890b 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DictConfMapper.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DictConfMapper.java @@ -5,4 +5,5 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.DictConfDO; import org.apache.ibatis.annotations.Mapper; @Mapper -public interface DictConfMapper extends BaseMapper {} +public interface DictConfMapper extends BaseMapper { +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DictTaskMapper.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DictTaskMapper.java index cf4a54118..fb1fa2943 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DictTaskMapper.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DictTaskMapper.java @@ -5,4 +5,5 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.DictTaskDO; import org.apache.ibatis.annotations.Mapper; @Mapper -public interface DictTaskMapper extends BaseMapper {} +public interface DictTaskMapper extends BaseMapper { +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DimensionDOMapper.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DimensionDOMapper.java index bd5239aff..14aa50ba9 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DimensionDOMapper.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DimensionDOMapper.java @@ -5,4 +5,5 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.DimensionDO import org.apache.ibatis.annotations.Mapper; @Mapper -public interface DimensionDOMapper extends BaseMapper {} +public interface DimensionDOMapper extends BaseMapper { +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DomainDOMapper.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DomainDOMapper.java index 4012fd1e6..af9756388 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DomainDOMapper.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/DomainDOMapper.java @@ -5,4 +5,5 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.DomainDO; import org.apache.ibatis.annotations.Mapper; @Mapper -public interface DomainDOMapper extends BaseMapper {} +public interface DomainDOMapper extends BaseMapper { +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/MetricDOMapper.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/MetricDOMapper.java index 3318912ea..b54066f2f 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/MetricDOMapper.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/MetricDOMapper.java @@ -5,4 +5,5 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.MetricDO; import org.apache.ibatis.annotations.Mapper; @Mapper -public interface MetricDOMapper extends BaseMapper {} +public interface MetricDOMapper extends BaseMapper { +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/MetricQueryDefaultConfigDOMapper.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/MetricQueryDefaultConfigDOMapper.java index 3f1db5bfa..b0fb4e7f8 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/MetricQueryDefaultConfigDOMapper.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/MetricQueryDefaultConfigDOMapper.java @@ -5,4 +5,5 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.MetricQuery import org.apache.ibatis.annotations.Mapper; @Mapper -public interface MetricQueryDefaultConfigDOMapper extends BaseMapper {} +public interface MetricQueryDefaultConfigDOMapper extends BaseMapper { +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/ModelDOMapper.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/ModelDOMapper.java index 9521bcf7b..13d739478 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/ModelDOMapper.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/ModelDOMapper.java @@ -5,4 +5,5 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.ModelDO; import org.apache.ibatis.annotations.Mapper; @Mapper -public interface ModelDOMapper extends BaseMapper {} +public interface ModelDOMapper extends BaseMapper { +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/ModelRelaDOMapper.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/ModelRelaDOMapper.java index ae062d209..d9f8c8127 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/ModelRelaDOMapper.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/ModelRelaDOMapper.java @@ -5,4 +5,5 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.ModelRelaDO import org.apache.ibatis.annotations.Mapper; @Mapper -public interface ModelRelaDOMapper extends BaseMapper {} +public interface ModelRelaDOMapper extends BaseMapper { +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/QueryRuleMapper.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/QueryRuleMapper.java index 6556c0682..e0db70589 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/QueryRuleMapper.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/QueryRuleMapper.java @@ -5,4 +5,5 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.QueryRuleDO import org.apache.ibatis.annotations.Mapper; @Mapper -public interface QueryRuleMapper extends BaseMapper {} +public interface QueryRuleMapper extends BaseMapper { +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/TagMapper.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/TagMapper.java index 30388d123..c08867081 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/TagMapper.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/TagMapper.java @@ -5,4 +5,5 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.TagDO; import org.apache.ibatis.annotations.Mapper; @Mapper -public interface TagMapper extends BaseMapper {} +public interface TagMapper extends BaseMapper { +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/TagObjectMapper.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/TagObjectMapper.java index 72d64e394..c34b7300f 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/TagObjectMapper.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/TagObjectMapper.java @@ -5,4 +5,5 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.TagObjectDO import org.apache.ibatis.annotations.Mapper; @Mapper -public interface TagObjectMapper extends BaseMapper {} +public interface TagObjectMapper extends BaseMapper { +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/TermMapper.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/TermMapper.java index dcfb85b1d..93e124a21 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/TermMapper.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/mapper/TermMapper.java @@ -5,4 +5,5 @@ import com.tencent.supersonic.headless.server.persistence.dataobject.TermDO; import org.apache.ibatis.annotations.Mapper; @Mapper -public interface TermMapper extends BaseMapper {} +public interface TermMapper extends BaseMapper { +} diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/DateInfoRepositoryImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/DateInfoRepositoryImpl.java index bf3cee7b4..4ae5b1377 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/DateInfoRepositoryImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/DateInfoRepositoryImpl.java @@ -26,7 +26,8 @@ public class DateInfoRepositoryImpl implements DateInfoRepository { private ObjectMapper mapper = new ObjectMapper(); - @Autowired private DateInfoMapper dateInfoMapper; + @Autowired + private DateInfoMapper dateInfoMapper; @Override public Integer upsertDateInfo(List dateInfoCommends) { @@ -36,22 +37,19 @@ public class DateInfoRepositoryImpl implements DateInfoRepository { return 0; } - dateInfoCommends.stream() - .forEach( - commend -> { - DateInfoDO dateInfoDO = new DateInfoDO(); - BeanUtils.copyProperties(commend, dateInfoDO); - try { - dateInfoDO.setUnavailableDateList( - mapper.writeValueAsString( - commend.getUnavailableDateList())); - dateInfoDO.setCreatedBy(Constants.ADMIN_LOWER); - dateInfoDO.setUpdatedBy(Constants.ADMIN_LOWER); - } catch (JsonProcessingException e) { - log.info("e,", e); - } - dateInfoDOList.add(dateInfoDO); - }); + dateInfoCommends.stream().forEach(commend -> { + DateInfoDO dateInfoDO = new DateInfoDO(); + BeanUtils.copyProperties(commend, dateInfoDO); + try { + dateInfoDO.setUnavailableDateList( + mapper.writeValueAsString(commend.getUnavailableDateList())); + dateInfoDO.setCreatedBy(Constants.ADMIN_LOWER); + dateInfoDO.setUpdatedBy(Constants.ADMIN_LOWER); + } catch (JsonProcessingException e) { + log.info("e,", e); + } + dateInfoDOList.add(dateInfoDO); + }); return batchUpsert(dateInfoDOList); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/DictRepositoryImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/DictRepositoryImpl.java index a5b3f2ab7..1abaa39eb 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/DictRepositoryImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/DictRepositoryImpl.java @@ -36,11 +36,8 @@ public class DictRepositoryImpl implements DictRepository { private final DictUtils dictConverter; private final DimensionService dimensionService; - public DictRepositoryImpl( - DictTaskMapper dictTaskMapper, - DictConfMapper dictConfMapper, - DictUtils dictConverter, - DimensionService dimensionService) { + public DictRepositoryImpl(DictTaskMapper dictTaskMapper, DictConfMapper dictConfMapper, + DictUtils dictConverter, DimensionService dimensionService) { this.dictTaskMapper = dictTaskMapper; this.dictConfMapper = dictConfMapper; this.dictConverter = dictConverter; @@ -90,11 +87,9 @@ public class DictRepositoryImpl implements DictRepository { QueryWrapper wrapper = new QueryWrapper<>(); wrapper.lambda().eq(DictTaskDO::getItemId, taskReq.getItemId()); wrapper.lambda().eq(DictTaskDO::getType, taskReq.getType()); - List dictTaskDOList = - dictTaskMapper.selectList(wrapper).stream() - .sorted(Comparator.comparing(DictTaskDO::getCreatedAt).reversed()) - .limit(dictTaskNum) - .collect(Collectors.toList()); + List dictTaskDOList = dictTaskMapper.selectList(wrapper).stream() + .sorted(Comparator.comparing(DictTaskDO::getCreatedAt).reversed()) + .limit(dictTaskNum).collect(Collectors.toList()); if (CollectionUtils.isEmpty(dictTaskDOList)) { return taskResp; } @@ -114,10 +109,8 @@ public class DictRepositoryImpl implements DictRepository { @Override public Long editDictConf(DictConfDO dictConfDO) { DictItemFilter filter = - DictItemFilter.builder() - .type(TypeEnums.valueOf(dictConfDO.getType())) - .itemId(dictConfDO.getItemId()) - .build(); + DictItemFilter.builder().type(TypeEnums.valueOf(dictConfDO.getType())) + .itemId(dictConfDO.getItemId()).build(); List dictConfDOList = getDictConfDOList(filter); if (CollectionUtils.isEmpty(dictConfDOList)) { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/DimensionRepositoryImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/DimensionRepositoryImpl.java index 931853c0c..4fe25fcda 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/DimensionRepositoryImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/DimensionRepositoryImpl.java @@ -17,8 +17,8 @@ public class DimensionRepositoryImpl implements DimensionRepository { private DimensionDOCustomMapper dimensionDOCustomMapper; - public DimensionRepositoryImpl( - DimensionDOMapper dimensionDOMapper, DimensionDOCustomMapper dimensionDOCustomMapper) { + public DimensionRepositoryImpl(DimensionDOMapper dimensionDOMapper, + DimensionDOCustomMapper dimensionDOCustomMapper) { this.dimensionDOMapper = dimensionDOMapper; this.dimensionDOCustomMapper = dimensionDOCustomMapper; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/MetricRepositoryImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/MetricRepositoryImpl.java index f565cdadd..f79d8eabd 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/MetricRepositoryImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/MetricRepositoryImpl.java @@ -22,8 +22,7 @@ public class MetricRepositoryImpl implements MetricRepository { private MetricQueryDefaultConfigDOMapper metricQueryDefaultConfigDOMapper; - public MetricRepositoryImpl( - MetricDOMapper metricDOMapper, + public MetricRepositoryImpl(MetricDOMapper metricDOMapper, MetricDOCustomMapper metricDOCustomMapper, MetricQueryDefaultConfigDOMapper metricQueryDefaultConfigDOMapper) { this.metricDOMapper = metricDOMapper; @@ -95,9 +94,7 @@ public class MetricRepositoryImpl implements MetricRepository { @Override public MetricQueryDefaultConfigDO getDefaultQueryConfig(Long metricId, String userName) { QueryWrapper queryWrapper = new QueryWrapper<>(); - queryWrapper - .lambda() - .eq(MetricQueryDefaultConfigDO::getMetricId, metricId) + queryWrapper.lambda().eq(MetricQueryDefaultConfigDO::getMetricId, metricId) .eq(MetricQueryDefaultConfigDO::getCreatedBy, userName); return metricQueryDefaultConfigDOMapper.selectOne(queryWrapper); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/ModelRepositoryImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/ModelRepositoryImpl.java index a8b212458..8d481e92a 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/ModelRepositoryImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/ModelRepositoryImpl.java @@ -19,8 +19,8 @@ public class ModelRepositoryImpl implements ModelRepository { private ModelDOCustomMapper modelDOCustomMapper; - public ModelRepositoryImpl( - ModelDOMapper modelDOMapper, ModelDOCustomMapper modelDOCustomMapper) { + public ModelRepositoryImpl(ModelDOMapper modelDOMapper, + ModelDOCustomMapper modelDOCustomMapper) { this.modelDOMapper = modelDOMapper; this.modelDOCustomMapper = modelDOCustomMapper; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/StatRepositoryImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/StatRepositoryImpl.java index 0ebf9ef31..59f941f88 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/StatRepositoryImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/StatRepositoryImpl.java @@ -45,34 +45,24 @@ public class StatRepositoryImpl implements StatRepository { List result = new ArrayList<>(); List statInfos = statMapper.getStatInfo(itemUseReq); Map map = new ConcurrentHashMap<>(); - statInfos.stream() - .forEach( - stat -> { - String dimensions = stat.getDimensions(); - String metrics = stat.getMetrics(); - if (Objects.nonNull(stat.getDataSetId())) { - updateStatMapInfo( - map, - dimensions, - TypeEnums.DIMENSION.name().toLowerCase(), - stat.getDataSetId()); - updateStatMapInfo( - map, - metrics, - TypeEnums.METRIC.name().toLowerCase(), - stat.getDataSetId()); - } - }); - map.forEach( - (k, v) -> { - Long classId = Long.parseLong(k.split(AT_SYMBOL + AT_SYMBOL)[0]); - String type = k.split(AT_SYMBOL + AT_SYMBOL)[1]; - String nameEn = k.split(AT_SYMBOL + AT_SYMBOL)[2]; - result.add(new ItemUseResp(classId, type, nameEn, v)); - }); + statInfos.stream().forEach(stat -> { + String dimensions = stat.getDimensions(); + String metrics = stat.getMetrics(); + if (Objects.nonNull(stat.getDataSetId())) { + updateStatMapInfo(map, dimensions, TypeEnums.DIMENSION.name().toLowerCase(), + stat.getDataSetId()); + updateStatMapInfo(map, metrics, TypeEnums.METRIC.name().toLowerCase(), + stat.getDataSetId()); + } + }); + map.forEach((k, v) -> { + Long classId = Long.parseLong(k.split(AT_SYMBOL + AT_SYMBOL)[0]); + String type = k.split(AT_SYMBOL + AT_SYMBOL)[1]; + String nameEn = k.split(AT_SYMBOL + AT_SYMBOL)[2]; + result.add(new ItemUseResp(classId, type, nameEn, v)); + }); - return result.stream() - .sorted(Comparator.comparing(ItemUseResp::getUseCnt).reversed()) + return result.stream().sorted(Comparator.comparing(ItemUseResp::getUseCnt).reversed()) .collect(Collectors.toList()); } @@ -81,24 +71,21 @@ public class StatRepositoryImpl implements StatRepository { return statMapper.getStatInfo(itemUseCommend); } - private void updateStatMapInfo( - Map map, String dimensions, String type, Long dataSetId) { + private void updateStatMapInfo(Map map, String dimensions, String type, + Long dataSetId) { if (StringUtils.isNotEmpty(dimensions)) { try { List dimensionList = mapper.readValue(dimensions, new TypeReference>() {}); - dimensionList.stream() - .forEach( - dimension -> { - String key = - dataSetId + AT_SYMBOL + AT_SYMBOL + type + AT_SYMBOL - + AT_SYMBOL + dimension; - if (map.containsKey(key)) { - map.put(key, map.get(key) + 1); - } else { - map.put(key, 1L); - } - }); + dimensionList.stream().forEach(dimension -> { + String key = dataSetId + AT_SYMBOL + AT_SYMBOL + type + AT_SYMBOL + AT_SYMBOL + + dimension; + if (map.containsKey(key)) { + map.put(key, map.get(key) + 1); + } else { + map.put(key, 1L); + } + }); } catch (Exception e) { log.warn("e:{}", e); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/TagRepositoryImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/TagRepositoryImpl.java index 6a70e254a..667faaed6 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/TagRepositoryImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/persistence/repository/impl/TagRepositoryImpl.java @@ -63,8 +63,8 @@ public class TagRepositoryImpl implements TagRepository { } if (Objects.nonNull(tagDeleteReq.getTagDefineType()) && CollectionUtils.isNotEmpty(tagDeleteReq.getItemIds())) { - tagCustomMapper.deleteBatchByType( - tagDeleteReq.getItemIds(), tagDeleteReq.getTagDefineType().name()); + tagCustomMapper.deleteBatchByType(tagDeleteReq.getItemIds(), + tagDeleteReq.getTagDefineType().name()); } } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/EntityInfoProcessor.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/EntityInfoProcessor.java index 024d5aeb3..1e81baac2 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/EntityInfoProcessor.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/EntityInfoProcessor.java @@ -16,24 +16,18 @@ public class EntityInfoProcessor implements ResultProcessor { @Override public void process(ParseResp parseResp, ChatQueryContext chatQueryContext) { - parseResp - .getSelectedParses() - .forEach( - parseInfo -> { - String queryMode = parseInfo.getQueryMode(); - if (!QueryManager.isDetailQuery(queryMode) - && !QueryManager.isMetricQuery(queryMode)) { - return; - } + parseResp.getSelectedParses().forEach(parseInfo -> { + String queryMode = parseInfo.getQueryMode(); + if (!QueryManager.isDetailQuery(queryMode) && !QueryManager.isMetricQuery(queryMode)) { + return; + } - SemanticLayerService semanticService = - ContextUtils.getBean(SemanticLayerService.class); - DataSetSchema dataSetSchema = - semanticService.getDataSetSchema(parseInfo.getDataSetId()); - EntityInfo entityInfo = - semanticService.getEntityInfo( - parseInfo, dataSetSchema, chatQueryContext.getUser()); - parseInfo.setEntityInfo(entityInfo); - }); + SemanticLayerService semanticService = ContextUtils.getBean(SemanticLayerService.class); + DataSetSchema dataSetSchema = + semanticService.getDataSetSchema(parseInfo.getDataSetId()); + EntityInfo entityInfo = semanticService.getEntityInfo(parseInfo, dataSetSchema, + chatQueryContext.getUser()); + parseInfo.setEntityInfo(entityInfo); + }); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/ParseInfoProcessor.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/ParseInfoProcessor.java index b27295ecb..c72be20c2 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/ParseInfoProcessor.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/processor/ParseInfoProcessor.java @@ -82,39 +82,33 @@ public class ParseInfoProcessor implements ResultProcessor { } else if (QueryType.DETAIL.equals(parseInfo.getQueryType())) { List selectFields = SqlSelectHelper.getSelectFields(s2SQL); List selectDimensions = filterDateField(dsSchema, selectFields); - parseInfo.setDimensions( - matchSchemaElements(selectDimensions, dsSchema.getDimensions())); + parseInfo + .setDimensions(matchSchemaElements(selectDimensions, dsSchema.getDimensions())); } } - private Set matchSchemaElements( - List allFields, Set elements) { - return elements.stream() - .filter( - schemaElement -> { - if (CollectionUtils.isEmpty(schemaElement.getAlias())) { - return allFields.contains(schemaElement.getName()); - } - Set allFieldsSet = new HashSet<>(allFields); - Set aliasSet = new HashSet<>(schemaElement.getAlias()); - List intersection = - allFieldsSet.stream() - .filter(aliasSet::contains) - .collect(Collectors.toList()); - return allFields.contains(schemaElement.getName()) - || !CollectionUtils.isEmpty(intersection); - }) - .collect(Collectors.toSet()); + private Set matchSchemaElements(List allFields, + Set elements) { + return elements.stream().filter(schemaElement -> { + if (CollectionUtils.isEmpty(schemaElement.getAlias())) { + return allFields.contains(schemaElement.getName()); + } + Set allFieldsSet = new HashSet<>(allFields); + Set aliasSet = new HashSet<>(schemaElement.getAlias()); + List intersection = + allFieldsSet.stream().filter(aliasSet::contains).collect(Collectors.toList()); + return allFields.contains(schemaElement.getName()) + || !CollectionUtils.isEmpty(intersection); + }).collect(Collectors.toSet()); } private List filterDateField(DataSetSchema dataSetSchema, List allFields) { - return allFields.stream() - .filter(entry -> !isPartitionDimension(dataSetSchema, entry)) + return allFields.stream().filter(entry -> !isPartitionDimension(dataSetSchema, entry)) .collect(Collectors.toList()); } - private List extractDimensionFilter( - DataSetSchema dsSchema, List fieldExpressions) { + private List extractDimensionFilter(DataSetSchema dsSchema, + List fieldExpressions) { Map fieldNameToElement = getNameToElement(dsSchema); List result = Lists.newArrayList(); @@ -139,15 +133,11 @@ public class ParseInfoProcessor implements ResultProcessor { return result; } - private DateConf extractDateFilter( - List fieldExpressions, DataSetSchema dataSetSchema) { - List dateExpressions = - fieldExpressions.stream() - .filter( - expression -> - isPartitionDimension( - dataSetSchema, expression.getFieldName())) - .collect(Collectors.toList()); + private DateConf extractDateFilter(List fieldExpressions, + DataSetSchema dataSetSchema) { + List dateExpressions = fieldExpressions.stream().filter( + expression -> isPartitionDimension(dataSetSchema, expression.getFieldName())) + .collect(Collectors.toList()); if (CollectionUtils.isEmpty(dateExpressions)) { return null; } @@ -164,20 +154,14 @@ public class ParseInfoProcessor implements ResultProcessor { dateInfo.setDateMode(DateConf.DateMode.BETWEEN); return dateInfo; } - if (containOperators( - firstExpression, - firstOperator, - FilterOperatorEnum.GREATER_THAN, + if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.GREATER_THAN, FilterOperatorEnum.GREATER_THAN_EQUALS)) { dateInfo.setStartDate(firstExpression.getFieldValue().toString()); if (hasSecondDate(dateExpressions)) { dateInfo.setEndDate(dateExpressions.get(1).getFieldValue().toString()); } } - if (containOperators( - firstExpression, - firstOperator, - FilterOperatorEnum.MINOR_THAN, + if (containOperators(firstExpression, firstOperator, FilterOperatorEnum.MINOR_THAN, FilterOperatorEnum.MINOR_THAN_EQUALS)) { dateInfo.setEndDate(firstExpression.getFieldValue().toString()); if (hasSecondDate(dateExpressions)) { @@ -191,17 +175,14 @@ public class ParseInfoProcessor implements ResultProcessor { if (TimeDimensionEnum.containsTimeDimension(sqlFieldName)) { return true; } - if (Objects.isNull(dataSetSchema) - || Objects.isNull(dataSetSchema.getPartitionDimension()) + if (Objects.isNull(dataSetSchema) || Objects.isNull(dataSetSchema.getPartitionDimension()) || Objects.isNull(dataSetSchema.getPartitionDimension().getName())) { return false; } return sqlFieldName.equalsIgnoreCase(dataSetSchema.getPartitionDimension().getName()); } - private boolean containOperators( - FieldExpression expression, - FilterOperatorEnum firstOperator, + private boolean containOperators(FieldExpression expression, FilterOperatorEnum firstOperator, FilterOperatorEnum... operatorEnums) { return (Arrays.asList(operatorEnums).contains(firstOperator) && Objects.nonNull(expression.getFieldValue())); @@ -220,21 +201,16 @@ public class ParseInfoProcessor implements ResultProcessor { allElements.addAll(dimensions); allElements.addAll(metrics); // support alias - return allElements.stream() - .flatMap( - schemaElement -> { - Set> result = new HashSet<>(); - result.add(Pair.of(schemaElement.getName(), schemaElement)); - List aliasList = schemaElement.getAlias(); - if (!org.springframework.util.CollectionUtils.isEmpty(aliasList)) { - for (String alias : aliasList) { - result.add(Pair.of(alias, schemaElement)); - } - } - return result.stream(); - }) - .collect( - Collectors.toMap( - Pair::getLeft, Pair::getRight, (value1, value2) -> value2)); + return allElements.stream().flatMap(schemaElement -> { + Set> result = new HashSet<>(); + result.add(Pair.of(schemaElement.getName(), schemaElement)); + List aliasList = schemaElement.getAlias(); + if (!org.springframework.util.CollectionUtils.isEmpty(aliasList)) { + for (String alias : aliasList) { + result.add(Pair.of(alias, schemaElement)); + } + } + return result.stream(); + }).collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (value1, value2) -> value2)); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/AppController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/AppController.java index d812b1f27..5788163e5 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/AppController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/AppController.java @@ -25,28 +25,27 @@ import org.springframework.web.bind.annotation.RestController; @RequestMapping("/api/semantic/app") public class AppController { - @Autowired private AppService appService; + @Autowired + private AppService appService; @PostMapping - public boolean save( - @RequestBody AppReq app, HttpServletRequest request, HttpServletResponse response) { + public boolean save(@RequestBody AppReq app, HttpServletRequest request, + HttpServletResponse response) { User user = UserHolder.findUser(request, response); appService.save(app, user); return true; } @PutMapping - public boolean update( - @RequestBody AppReq app, HttpServletRequest request, HttpServletResponse response) { + public boolean update(@RequestBody AppReq app, HttpServletRequest request, + HttpServletResponse response) { User user = UserHolder.findUser(request, response); appService.update(app, user); return true; } @PutMapping("/online/{id}") - public boolean online( - @PathVariable("id") Integer id, - HttpServletRequest request, + public boolean online(@PathVariable("id") Integer id, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); appService.online(id, user); @@ -54,9 +53,7 @@ public class AppController { } @PutMapping("/offline/{id}") - public boolean offline( - @PathVariable("id") Integer id, - HttpServletRequest request, + public boolean offline(@PathVariable("id") Integer id, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); appService.offline(id, user); @@ -64,9 +61,7 @@ public class AppController { } @DeleteMapping("/{id}") - public boolean delete( - @PathVariable("id") Integer id, - HttpServletRequest request, + public boolean delete(@PathVariable("id") Integer id, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); appService.delete(id, user); @@ -74,19 +69,15 @@ public class AppController { } @GetMapping("/{id}") - public AppDetailResp getApp( - @PathVariable("id") Integer id, - HttpServletRequest request, + public AppDetailResp getApp(@PathVariable("id") Integer id, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return appService.getApp(id, user); } @PostMapping("/page") - public PageInfo pageApp( - @RequestBody AppQueryReq appQueryReq, - HttpServletRequest request, - HttpServletResponse response) { + public PageInfo pageApp(@RequestBody AppQueryReq appQueryReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return appService.pageApp(appQueryReq, user); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/CanvasController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/CanvasController.java index 601ac5e44..a86891f20 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/CanvasController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/CanvasController.java @@ -24,13 +24,12 @@ import java.util.List; @RequestMapping("/api/semantic/viewInfo") public class CanvasController { - @Autowired private CanvasService canvasService; + @Autowired + private CanvasService canvasService; @PostMapping("/createOrUpdateViewInfo") - public CanvasDO createOrUpdateCanvas( - @RequestBody CanvasReq canvasReq, - HttpServletRequest request, - HttpServletResponse response) { + public CanvasDO createOrUpdateCanvas(@RequestBody CanvasReq canvasReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return canvasService.createOrUpdateCanvas(canvasReq, user); } @@ -46,10 +45,8 @@ public class CanvasController { } @GetMapping("/getDomainSchemaRela/{domainId}") - public List getDomainSchema( - @PathVariable("domainId") Long domainId, - HttpServletRequest request, - HttpServletResponse response) { + public List getDomainSchema(@PathVariable("domainId") Long domainId, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return canvasService.getCanvasSchema(domainId, user); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ClassController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ClassController.java index 3b3aecf71..abaa7e64c 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ClassController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ClassController.java @@ -40,9 +40,7 @@ public class ClassController { * @return */ @PostMapping("/create") - public ClassResp create( - @RequestBody @Valid ClassReq classReq, - HttpServletRequest request, + public ClassResp create(@RequestBody @Valid ClassReq classReq, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return classService.create(classReq, user); @@ -57,9 +55,7 @@ public class ClassController { * @return */ @PutMapping("/update") - public ClassResp update( - @RequestBody @Valid ClassReq classReq, - HttpServletRequest request, + public ClassResp update(@RequestBody @Valid ClassReq classReq, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return classService.update(classReq, user); @@ -75,12 +71,8 @@ public class ClassController { * @throws Exception */ @DeleteMapping("delete/{id}/{force}") - public Boolean delete( - @PathVariable("id") Long id, - @PathVariable("force") Boolean force, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public Boolean delete(@PathVariable("id") Long id, @PathVariable("force") Boolean force, + HttpServletRequest request, HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); return classService.delete(id, force, user); } @@ -95,9 +87,7 @@ public class ClassController { * @throws Exception */ @GetMapping("delete/{id}/{force}") - public List get( - @RequestBody @Valid ClassFilter filter, - HttpServletRequest request, + public List get(@RequestBody @Valid ClassFilter filter, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return classService.getClassList(filter, user); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/CollectController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/CollectController.java index dcd4eade8..b9ca30394 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/CollectController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/CollectController.java @@ -26,27 +26,23 @@ public class CollectController { } @PostMapping("/createCollectionIndicators") - public boolean createCollectionIndicators( - @RequestBody CollectDO collectDO, - HttpServletRequest request, - HttpServletResponse response) { + public boolean createCollectionIndicators(@RequestBody CollectDO collectDO, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return collectService.collect(user, collectDO); } @Deprecated @DeleteMapping("/deleteCollectionIndicators/{id}") - public boolean deleteCollectionIndicators( - @PathVariable Long id, HttpServletRequest request, HttpServletResponse response) { + public boolean deleteCollectionIndicators(@PathVariable Long id, HttpServletRequest request, + HttpServletResponse response) { User user = UserHolder.findUser(request, response); return collectService.unCollect(user, id); } @PostMapping("/deleteCollectionIndicators") - public boolean deleteCollectionIndicators( - @RequestBody CollectDO collectDO, - HttpServletRequest request, - HttpServletResponse response) { + public boolean deleteCollectionIndicators(@RequestBody CollectDO collectDO, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return collectService.unCollect(user, collectDO); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/DataSetController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/DataSetController.java index da3046010..238723414 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/DataSetController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/DataSetController.java @@ -26,21 +26,18 @@ import java.util.List; @RequestMapping("/api/semantic/dataSet") public class DataSetController { - @Autowired private DataSetService dataSetService; + @Autowired + private DataSetService dataSetService; @PostMapping - public DataSetResp save( - @RequestBody DataSetReq dataSetReq, - HttpServletRequest request, + public DataSetResp save(@RequestBody DataSetReq dataSetReq, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return dataSetService.save(dataSetReq, user); } @PutMapping - public DataSetResp update( - @RequestBody DataSetReq dataSetReq, - HttpServletRequest request, + public DataSetResp update(@RequestBody DataSetReq dataSetReq, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return dataSetService.update(dataSetReq, user); @@ -59,8 +56,8 @@ public class DataSetController { } @DeleteMapping("/{id}") - public Boolean delete( - @PathVariable("id") Long id, HttpServletRequest request, HttpServletResponse response) { + public Boolean delete(@PathVariable("id") Long id, HttpServletRequest request, + HttpServletResponse response) { User user = UserHolder.findUser(request, response); dataSetService.delete(id, user); return true; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/DatabaseController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/DatabaseController.java index 274d25a2d..30969dfe5 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/DatabaseController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/DatabaseController.java @@ -37,33 +37,29 @@ public class DatabaseController { } @PostMapping("/testConnect") - public boolean testConnect( - @RequestBody DatabaseReq databaseReq, - HttpServletRequest request, + public boolean testConnect(@RequestBody DatabaseReq databaseReq, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return databaseService.testConnect(databaseReq, user); } @PostMapping("/createOrUpdateDatabase") - public DatabaseResp createOrUpdateDatabase( - @RequestBody DatabaseReq databaseReq, - HttpServletRequest request, - HttpServletResponse response) { + public DatabaseResp createOrUpdateDatabase(@RequestBody DatabaseReq databaseReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return databaseService.createOrUpdateDatabase(databaseReq, user); } @GetMapping("/{id}") - public DatabaseResp getDatabase( - @PathVariable("id") Long id, HttpServletRequest request, HttpServletResponse response) { + public DatabaseResp getDatabase(@PathVariable("id") Long id, HttpServletRequest request, + HttpServletResponse response) { User user = UserHolder.findUser(request, response); return databaseService.getDatabase(id, user); } @GetMapping("/getDatabaseList") - public List getDatabaseList( - HttpServletRequest request, HttpServletResponse response) { + public List getDatabaseList(HttpServletRequest request, + HttpServletResponse response) { User user = UserHolder.findUser(request, response); return databaseService.getDatabaseList(user); } @@ -75,10 +71,8 @@ public class DatabaseController { } @PostMapping("/executeSql") - public SemanticQueryResp executeSql( - @RequestBody SqlExecuteReq sqlExecuteReq, - HttpServletRequest request, - HttpServletResponse response) { + public SemanticQueryResp executeSql(@RequestBody SqlExecuteReq sqlExecuteReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return databaseService.executeSql(sqlExecuteReq, sqlExecuteReq.getId(), user); } @@ -89,17 +83,14 @@ public class DatabaseController { } @RequestMapping("/getTables") - public List getTables( - @RequestParam("databaseId") Long databaseId, @RequestParam("db") String db) - throws SQLException { + public List getTables(@RequestParam("databaseId") Long databaseId, + @RequestParam("db") String db) throws SQLException { return databaseService.getTables(databaseId, db); } @RequestMapping("/getColumnsByName") - public List getColumnsByName( - @RequestParam("databaseId") Long databaseId, - @RequestParam("db") String db, - @RequestParam("table") String table) + public List getColumnsByName(@RequestParam("databaseId") Long databaseId, + @RequestParam("db") String db, @RequestParam("table") String table) throws SQLException { return databaseService.getColumns(databaseId, db, table); } @@ -110,8 +101,8 @@ public class DatabaseController { } @GetMapping("/getDatabaseParameters") - public Map> getDatabaseParameters( - HttpServletRequest request, HttpServletResponse response) { + public Map> getDatabaseParameters(HttpServletRequest request, + HttpServletResponse response) { User user = UserHolder.findUser(request, response); return databaseService.getDatabaseParameters(user); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/DimensionController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/DimensionController.java index 43488ae82..81d5e3877 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/DimensionController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/DimensionController.java @@ -34,9 +34,11 @@ import java.util.List; @RequestMapping("/api/semantic/dimension") public class DimensionController { - @Autowired private DimensionService dimensionService; + @Autowired + private DimensionService dimensionService; - @Autowired private SemanticLayerService queryService; + @Autowired + private SemanticLayerService queryService; /** * 创建维度 @@ -44,60 +46,46 @@ public class DimensionController { * @param dimensionReq */ @PostMapping("/createDimension") - public DimensionResp createDimension( - @RequestBody DimensionReq dimensionReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public DimensionResp createDimension(@RequestBody DimensionReq dimensionReq, + HttpServletRequest request, HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); return dimensionService.createDimension(dimensionReq, user); } @PostMapping("/updateDimension") - public Boolean updateDimension( - @RequestBody DimensionReq dimensionReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public Boolean updateDimension(@RequestBody DimensionReq dimensionReq, + HttpServletRequest request, HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); dimensionService.updateDimension(dimensionReq, user); return true; } @PostMapping("/batchUpdateStatus") - public Boolean batchUpdateStatus( - @RequestBody MetaBatchReq metaBatchReq, - HttpServletRequest request, - HttpServletResponse response) { + public Boolean batchUpdateStatus(@RequestBody MetaBatchReq metaBatchReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); dimensionService.batchUpdateStatus(metaBatchReq, user); return true; } @PostMapping("/batchUpdateSensitiveLevel") - public Boolean batchUpdateSensitiveLevel( - @RequestBody MetaBatchReq metaBatchReq, - HttpServletRequest request, - HttpServletResponse response) { + public Boolean batchUpdateSensitiveLevel(@RequestBody MetaBatchReq metaBatchReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); dimensionService.batchUpdateSensitiveLevel(metaBatchReq, user); return true; } @PostMapping("/mockDimensionAlias") - public List mockMetricAlias( - @RequestBody DimensionReq dimensionReq, - HttpServletRequest request, - HttpServletResponse response) { + public List mockMetricAlias(@RequestBody DimensionReq dimensionReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return dimensionService.mockAlias(dimensionReq, "dimension", user); } @PostMapping("/mockDimensionValuesAlias") - public List mockDimensionValuesAlias( - @RequestBody DimensionReq dimensionReq, - HttpServletRequest request, - HttpServletResponse response) { + public List mockDimensionValuesAlias(@RequestBody DimensionReq dimensionReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return dimensionService.mockDimensionValueAlias(dimensionReq, user); } @@ -115,8 +103,7 @@ public class DimensionController { } @GetMapping("/{modelId}/{dimensionName}") - public DimensionResp getDimensionDescByNameAndId( - @PathVariable("modelId") Long modelId, + public DimensionResp getDimensionDescByNameAndId(@PathVariable("modelId") Long modelId, @PathVariable("dimensionName") String dimensionBizName) { return dimensionService.getDimension(dimensionBizName, modelId); } @@ -127,17 +114,15 @@ public class DimensionController { } @PostMapping("/queryDimValue") - public SemanticQueryResp queryDimValue( - @RequestBody DimensionValueReq dimensionValueReq, - HttpServletRequest request, - HttpServletResponse response) { + public SemanticQueryResp queryDimValue(@RequestBody DimensionValueReq dimensionValueReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return queryService.queryDimensionValue(dimensionValueReq, user); } @DeleteMapping("deleteDimension/{id}") - public Boolean deleteDimension( - @PathVariable("id") Long id, HttpServletRequest request, HttpServletResponse response) { + public Boolean deleteDimension(@PathVariable("id") Long id, HttpServletRequest request, + HttpServletResponse response) { User user = UserHolder.findUser(request, response); dimensionService.deleteDimension(id, user); return true; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/DomainController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/DomainController.java index 32416c655..b55daec1f 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/DomainController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/DomainController.java @@ -32,19 +32,15 @@ public class DomainController { } @PostMapping("/createDomain") - public DomainResp createDomain( - @RequestBody DomainReq domainReq, - HttpServletRequest request, + public DomainResp createDomain(@RequestBody DomainReq domainReq, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return domainService.createDomain(domainReq, user); } @PostMapping("/updateDomain") - public DomainResp updateDomain( - @RequestBody DomainUpdateReq domainUpdateReq, - HttpServletRequest request, - HttpServletResponse response) { + public DomainResp updateDomain(@RequestBody DomainUpdateReq domainUpdateReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return domainService.updateDomain(domainUpdateReq, user); } @@ -56,8 +52,8 @@ public class DomainController { } @GetMapping("/getDomainList") - public List getDomainList( - HttpServletRequest request, HttpServletResponse response) { + public List getDomainList(HttpServletRequest request, + HttpServletResponse response) { User user = UserHolder.findUser(request, response); return domainService.getDomainListWithAdminAuth(user); } @@ -69,9 +65,7 @@ public class DomainController { @GetMapping("/getDomainListByIds/{domainIds}") public List getDomainListByIds(@PathVariable("domainIds") String domainIds) { - return domainService.getDomainList( - Arrays.stream(domainIds.split(",")) - .map(Long::parseLong) - .collect(Collectors.toList())); + return domainService.getDomainList(Arrays.stream(domainIds.split(",")).map(Long::parseLong) + .collect(Collectors.toList())); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/KnowledgeController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/KnowledgeController.java index 53e9251dc..696f2ebcb 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/KnowledgeController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/KnowledgeController.java @@ -34,17 +34,23 @@ import java.util.List; @RequestMapping("/api/semantic/knowledge") public class KnowledgeController { - @Autowired private DictTaskService taskService; + @Autowired + private DictTaskService taskService; - @Autowired private DictConfService confService; + @Autowired + private DictConfService confService; - @Autowired private MetaEmbeddingTask metaEmbeddingTask; + @Autowired + private MetaEmbeddingTask metaEmbeddingTask; - @Autowired private DictionaryReloadTask dictionaryReloadTask; + @Autowired + private DictionaryReloadTask dictionaryReloadTask; - @Autowired private ExemplarService exemplarService; + @Autowired + private ExemplarService exemplarService; - @Autowired private EmbeddingService embeddingService; + @Autowired + private EmbeddingService embeddingService; /** * addDictConf-新增item的字典配置 Add configuration information for dictionary entries @@ -52,10 +58,8 @@ public class KnowledgeController { * @param dictItemReq */ @PostMapping("/conf") - public DictItemResp addDictConf( - @RequestBody @Valid DictItemReq dictItemReq, - HttpServletRequest request, - HttpServletResponse response) { + public DictItemResp addDictConf(@RequestBody @Valid DictItemReq dictItemReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return confService.addDictConf(dictItemReq, user); } @@ -66,10 +70,8 @@ public class KnowledgeController { * @param dictItemReq */ @PutMapping("/conf") - public DictItemResp editDictConf( - @RequestBody @Valid DictItemReq dictItemReq, - HttpServletRequest request, - HttpServletResponse response) { + public DictItemResp editDictConf(@RequestBody @Valid DictItemReq dictItemReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return confService.editDictConf(dictItemReq, user); } @@ -80,10 +82,8 @@ public class KnowledgeController { * @param filter */ @PostMapping("/conf/query") - public List queryDictConf( - @RequestBody @Valid DictItemFilter filter, - HttpServletRequest request, - HttpServletResponse response) { + public List queryDictConf(@RequestBody @Valid DictItemFilter filter, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return confService.queryDictConf(filter, user); } @@ -94,9 +94,7 @@ public class KnowledgeController { * @param taskReq */ @PostMapping("/task") - public Long addDictTask( - @RequestBody DictSingleTaskReq taskReq, - HttpServletRequest request, + public Long addDictTask(@RequestBody DictSingleTaskReq taskReq, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return taskService.addDictTask(taskReq, user); @@ -108,9 +106,7 @@ public class KnowledgeController { * @param taskReq */ @PutMapping("/task/delete") - public Long deleteDictTask( - @RequestBody DictSingleTaskReq taskReq, - HttpServletRequest request, + public Long deleteDictTask(@RequestBody DictSingleTaskReq taskReq, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return taskService.deleteDictTask(taskReq, user); @@ -128,10 +124,8 @@ public class KnowledgeController { * @param taskReq */ @PostMapping("/task/search") - public DictTaskResp queryLatestDictTask( - @RequestBody DictSingleTaskReq taskReq, - HttpServletRequest request, - HttpServletResponse response) { + public DictTaskResp queryLatestDictTask(@RequestBody DictSingleTaskReq taskReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return taskService.queryLatestDictTask(taskReq, user); } @@ -161,10 +155,8 @@ public class KnowledgeController { * @param dictValueReq */ @PostMapping("/dict/data") - public PageInfo queryDictValue( - @RequestBody @Valid DictValueReq dictValueReq, - HttpServletRequest request, - HttpServletResponse response) { + public PageInfo queryDictValue(@RequestBody @Valid DictValueReq dictValueReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return taskService.queryDictValue(dictValueReq, user); } @@ -175,10 +167,8 @@ public class KnowledgeController { * @param dictValueReq */ @PostMapping("/dict/file") - public String queryDictFilePath( - @RequestBody @Valid DictValueReq dictValueReq, - HttpServletRequest request, - HttpServletResponse response) { + public String queryDictFilePath(@RequestBody @Valid DictValueReq dictValueReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return taskService.queryDictFilePath(dictValueReq, user); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/MetricController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/MetricController.java index 0ecd54156..9ddc89c18 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/MetricController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/MetricController.java @@ -40,39 +40,29 @@ public class MetricController { } @PostMapping("/createMetric") - public MetricResp createMetric( - @RequestBody MetricReq metricReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public MetricResp createMetric(@RequestBody MetricReq metricReq, HttpServletRequest request, + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); return metricService.createMetric(metricReq, user); } @PostMapping("/updateMetric") - public MetricResp updateMetric( - @RequestBody MetricReq metricReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public MetricResp updateMetric(@RequestBody MetricReq metricReq, HttpServletRequest request, + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); return metricService.updateMetric(metricReq, user); } @PostMapping("/batchUpdateStatus") - public Boolean batchUpdateStatus( - @RequestBody MetaBatchReq metaBatchReq, - HttpServletRequest request, - HttpServletResponse response) { + public Boolean batchUpdateStatus(@RequestBody MetaBatchReq metaBatchReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); metricService.batchUpdateStatus(metaBatchReq, user); return true; } @PostMapping("/batchPublish") - public Boolean batchPublish( - @RequestBody MetaBatchReq metaBatchReq, - HttpServletRequest request, + public Boolean batchPublish(@RequestBody MetaBatchReq metaBatchReq, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); metricService.batchPublish(metaBatchReq.getIds(), user); @@ -80,40 +70,32 @@ public class MetricController { } @PostMapping("/batchUnPublish") - public Boolean batchUnPublish( - @RequestBody MetaBatchReq metaBatchReq, - HttpServletRequest request, - HttpServletResponse response) { + public Boolean batchUnPublish(@RequestBody MetaBatchReq metaBatchReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); metricService.batchUnPublish(metaBatchReq.getIds(), user); return true; } @PostMapping("/batchUpdateClassifications") - public Boolean batchUpdateClassifications( - @RequestBody MetaBatchReq metaBatchReq, - HttpServletRequest request, - HttpServletResponse response) { + public Boolean batchUpdateClassifications(@RequestBody MetaBatchReq metaBatchReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); metricService.batchUpdateClassifications(metaBatchReq, user); return true; } @PostMapping("/batchUpdateSensitiveLevel") - public Boolean batchUpdateSensitiveLevel( - @RequestBody MetaBatchReq metaBatchReq, - HttpServletRequest request, - HttpServletResponse response) { + public Boolean batchUpdateSensitiveLevel(@RequestBody MetaBatchReq metaBatchReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); metricService.batchUpdateSensitiveLevel(metaBatchReq, user); return true; } @PostMapping("/mockMetricAlias") - public List mockMetricAlias( - @RequestBody MetricBaseReq metricReq, - HttpServletRequest request, - HttpServletResponse response) { + public List mockMetricAlias(@RequestBody MetricBaseReq metricReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return metricService.mockAlias(metricReq, "indicator", user); } @@ -130,32 +112,29 @@ public class MetricController { } @PostMapping("/queryMetric") - public PageInfo queryMetric( - @RequestBody PageMetricReq pageMetricReq, - HttpServletRequest request, - HttpServletResponse response) { + public PageInfo queryMetric(@RequestBody PageMetricReq pageMetricReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return metricService.queryMetricMarket(pageMetricReq, user); } @Deprecated @GetMapping("getMetric/{modelId}/{bizName}") - public MetricResp getMetric( - @PathVariable("modelId") Long modelId, @PathVariable("bizName") String bizName) { + public MetricResp getMetric(@PathVariable("modelId") Long modelId, + @PathVariable("bizName") String bizName) { return metricService.getMetric(modelId, bizName); } @GetMapping("getMetric/{id}") - public MetricResp getMetric( - @PathVariable("id") Long id, HttpServletRequest request, HttpServletResponse response) { + public MetricResp getMetric(@PathVariable("id") Long id, HttpServletRequest request, + HttpServletResponse response) { User user = UserHolder.findUser(request, response); return metricService.getMetric(id, user); } @DeleteMapping("deleteMetric/{id}") - public Boolean deleteMetric( - @PathVariable("id") Long id, HttpServletRequest request, HttpServletResponse response) - throws Exception { + public Boolean deleteMetric(@PathVariable("id") Long id, HttpServletRequest request, + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); metricService.deleteMetric(id, user); return true; @@ -186,8 +165,7 @@ public class MetricController { @PostMapping("/saveMetricQueryDefaultConfig") public boolean saveMetricQueryDefaultConfig( - @RequestBody MetricQueryDefaultConfig queryDefaultConfig, - HttpServletRequest request, + @RequestBody MetricQueryDefaultConfig queryDefaultConfig, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); metricService.saveMetricQueryDefaultConfig(queryDefaultConfig, user); @@ -196,8 +174,7 @@ public class MetricController { @RequestMapping("getMetricQueryDefaultConfig/{metricId}") public MetricQueryDefaultConfig getMetricQueryDefaultConfig( - @PathVariable("metricId") Long metricId, - HttpServletRequest request, + @PathVariable("metricId") Long metricId, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return metricService.getMetricQueryDefaultConfig(metricId, user); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ModelController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ModelController.java index 2fae146cc..35ea5680c 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ModelController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ModelController.java @@ -39,31 +39,23 @@ public class ModelController { } @PostMapping("/createModel") - public Boolean createModel( - @RequestBody ModelReq modelReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public Boolean createModel(@RequestBody ModelReq modelReq, HttpServletRequest request, + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); modelService.createModel(modelReq, user); return true; } @PostMapping("/updateModel") - public Boolean updateModel( - @RequestBody ModelReq modelReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public Boolean updateModel(@RequestBody ModelReq modelReq, HttpServletRequest request, + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); modelService.updateModel(modelReq, user); return true; } @DeleteMapping("/deleteModel/{modelId}") - public Boolean deleteModel( - @PathVariable("modelId") Long modelId, - HttpServletRequest request, + public Boolean deleteModel(@PathVariable("modelId") Long modelId, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); modelService.deleteModel(modelId, user); @@ -71,10 +63,8 @@ public class ModelController { } @GetMapping("/getModelList/{domainId}") - public List getModelList( - @PathVariable("domainId") Long domainId, - HttpServletRequest request, - HttpServletResponse response) { + public List getModelList(@PathVariable("domainId") Long domainId, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return modelService.getModelListWithAuth(user, domainId, AuthType.ADMIN); } @@ -86,10 +76,8 @@ public class ModelController { @GetMapping("/getModelListByIds/{modelIds}") public List getModelListByIds(@PathVariable("modelIds") String modelIds) { - List ids = - Arrays.stream(modelIds.split(",")) - .map(Long::parseLong) - .collect(Collectors.toList()); + List ids = Arrays.stream(modelIds.split(",")).map(Long::parseLong) + .collect(Collectors.toList()); ModelFilter modelFilter = new ModelFilter(); modelFilter.setIds(ids); return modelService.getModelList(modelFilter); @@ -106,10 +94,8 @@ public class ModelController { } @PostMapping("/batchUpdateStatus") - public Boolean batchUpdateStatus( - @RequestBody MetaBatchReq metaBatchReq, - HttpServletRequest request, - HttpServletResponse response) { + public Boolean batchUpdateStatus(@RequestBody MetaBatchReq metaBatchReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); modelService.batchUpdateStatus(metaBatchReq, user); return true; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ModelRelaController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ModelRelaController.java index a512af0ff..0cc2cc553 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ModelRelaController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/ModelRelaController.java @@ -19,7 +19,8 @@ import java.util.List; @RequestMapping("/api/semantic/modelRela") public class ModelRelaController { - @Autowired private ModelRelaService modelRelaService; + @Autowired + private ModelRelaService modelRelaService; @PostMapping public boolean save(@RequestBody ModelRela modelRela, User user) { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/QueryRuleController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/QueryRuleController.java index 69e5106eb..e35902aa3 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/QueryRuleController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/QueryRuleController.java @@ -39,10 +39,8 @@ public class QueryRuleController { * @throws Exception */ @PostMapping("/create") - public QueryRuleResp create( - @RequestBody @Validated QueryRuleReq queryRuleReq, - HttpServletRequest request, - HttpServletResponse response) { + public QueryRuleResp create(@RequestBody @Validated QueryRuleReq queryRuleReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return queryRuleService.addQueryRule(queryRuleReq, user); } @@ -57,10 +55,8 @@ public class QueryRuleController { * @throws Exception */ @PostMapping("/update") - public QueryRuleResp update( - @RequestBody @Validated QueryRuleReq queryRuleReq, - HttpServletRequest request, - HttpServletResponse response) { + public QueryRuleResp update(@RequestBody @Validated QueryRuleReq queryRuleReq, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return queryRuleService.updateQueryRule(queryRuleReq, user); } @@ -74,8 +70,8 @@ public class QueryRuleController { * @return */ @DeleteMapping("delete/{id}") - public Boolean delete( - @PathVariable("id") Long id, HttpServletRequest request, HttpServletResponse response) { + public Boolean delete(@PathVariable("id") Long id, HttpServletRequest request, + HttpServletResponse response) { User user = UserHolder.findUser(request, response); return queryRuleService.dropQueryRule(id, user); } @@ -88,10 +84,8 @@ public class QueryRuleController { * @return */ @PostMapping("query") - public List query( - @RequestBody @Validated QueryRuleFilter queryRuleFilter, - HttpServletRequest request, - HttpServletResponse response) { + public List query(@RequestBody @Validated QueryRuleFilter queryRuleFilter, + HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return queryRuleService.getQueryRuleList(queryRuleFilter, user); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/SchemaController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/SchemaController.java index 72ebf15a6..2f3ef0c79 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/SchemaController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/SchemaController.java @@ -21,20 +21,19 @@ import java.util.List; @RequestMapping("/api/semantic/schema") public class SchemaController { - @Autowired private SchemaService schemaService; + @Autowired + private SchemaService schemaService; @GetMapping("/domain/list") - public List getDomainList( - HttpServletRequest request, HttpServletResponse response) { + public List getDomainList(HttpServletRequest request, + HttpServletResponse response) { User user = UserHolder.findUser(request, response); return schemaService.getDomainList(user); } @GetMapping("/model/list") - public List getModelList( - @RequestParam("domainId") Long domainId, - @RequestParam("authType") String authType, - HttpServletRequest request, + public List getModelList(@RequestParam("domainId") Long domainId, + @RequestParam("authType") String authType, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return schemaService.getModelList(user, AuthType.valueOf(authType), domainId); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/TagController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/TagController.java index f1c1b4098..864356313 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/TagController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/TagController.java @@ -49,9 +49,8 @@ public class TagController { * @throws Exception */ @PostMapping("/create") - public TagResp create( - @RequestBody TagReq tagReq, HttpServletRequest request, HttpServletResponse response) - throws Exception { + public TagResp create(@RequestBody TagReq tagReq, HttpServletRequest request, + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); return tagMetaService.create(tagReq, user); } @@ -66,11 +65,8 @@ public class TagController { * @throws Exception */ @PostMapping("/create/batch") - public Integer createBatch( - @RequestBody @Valid List tagReqList, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public Integer createBatch(@RequestBody @Valid List tagReqList, + HttpServletRequest request, HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); return tagMetaService.createBatch(tagReqList, user); } @@ -85,11 +81,8 @@ public class TagController { * @throws Exception */ @PostMapping("/delete/batch") - public Boolean deleteBatch( - @RequestBody @Valid List tagDeleteReqList, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public Boolean deleteBatch(@RequestBody @Valid List tagDeleteReqList, + HttpServletRequest request, HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); return tagMetaService.deleteBatch(tagDeleteReqList, user); } @@ -104,8 +97,8 @@ public class TagController { * @throws Exception */ @DeleteMapping("delete/{id}") - public Boolean delete( - @PathVariable("id") Long id, HttpServletRequest request, HttpServletResponse response) { + public Boolean delete(@PathVariable("id") Long id, HttpServletRequest request, + HttpServletResponse response) { User user = UserHolder.findUser(request, response); tagMetaService.delete(id, user); return true; @@ -120,8 +113,8 @@ public class TagController { * @return */ @GetMapping("getTag/{id}") - public TagResp getTag( - @PathVariable("id") Long id, HttpServletRequest request, HttpServletResponse response) { + public TagResp getTag(@PathVariable("id") Long id, HttpServletRequest request, + HttpServletResponse response) { User user = UserHolder.findUser(request, response); return tagMetaService.getTag(id, user); } @@ -147,11 +140,8 @@ public class TagController { * @throws Exception */ @PostMapping("/value/distribution") - public ItemValueResp queryTagValue( - @RequestBody ItemValueReq itemValueReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public ItemValueResp queryTagValue(@RequestBody ItemValueReq itemValueReq, + HttpServletRequest request, HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); return tagQueryService.queryTagValue(itemValueReq, user); } @@ -166,11 +156,8 @@ public class TagController { * @throws Exception */ @PostMapping("/queryTag/market") - public PageInfo queryTagMarketPage( - @RequestBody TagFilterPageReq tagMarketPageReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public PageInfo queryTagMarketPage(@RequestBody TagFilterPageReq tagMarketPageReq, + HttpServletRequest request, HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); return tagMetaService.queryTagMarketPage(tagMarketPageReq, user); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/TagObjectController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/TagObjectController.java index aed1c99b7..0ce72b4f7 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/TagObjectController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/TagObjectController.java @@ -38,11 +38,8 @@ public class TagObjectController { * @throws Exception */ @PostMapping("/create") - public TagObjectResp create( - @RequestBody TagObjectReq tagObjectReq, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public TagObjectResp create(@RequestBody TagObjectReq tagObjectReq, HttpServletRequest request, + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); return tagObjectService.create(tagObjectReq, user); } @@ -56,9 +53,7 @@ public class TagObjectController { * @return */ @PostMapping("/update") - public TagObjectResp update( - @RequestBody TagObjectReq tagObjectReq, - HttpServletRequest request, + public TagObjectResp update(@RequestBody TagObjectReq tagObjectReq, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); return tagObjectService.update(tagObjectReq, user); @@ -74,9 +69,8 @@ public class TagObjectController { * @throws Exception */ @DeleteMapping("delete/{id}") - public Boolean delete( - @PathVariable("id") Long id, HttpServletRequest request, HttpServletResponse response) - throws Exception { + public Boolean delete(@PathVariable("id") Long id, HttpServletRequest request, + HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); tagObjectService.delete(id, user, true); return true; @@ -92,11 +86,8 @@ public class TagObjectController { * @throws Exception */ @PostMapping("/query") - public List queryTagObject( - @RequestBody TagObjectFilter filter, - HttpServletRequest request, - HttpServletResponse response) - throws Exception { + public List queryTagObject(@RequestBody TagObjectFilter filter, + HttpServletRequest request, HttpServletResponse response) throws Exception { User user = UserHolder.findUser(request, response); return tagObjectService.getTagObjects(filter, user); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/TermController.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/TermController.java index 39dfe64a2..43c693e53 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/TermController.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/rest/TermController.java @@ -25,12 +25,11 @@ import java.util.List; @RequestMapping("/api/semantic/term") public class TermController { - @Autowired private TermService termService; + @Autowired + private TermService termService; @PostMapping("/saveOrUpdate") - public boolean saveOrUpdate( - @RequestBody TermReq termReq, - HttpServletRequest request, + public boolean saveOrUpdate(@RequestBody TermReq termReq, HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); termService.saveOrUpdate(termReq, user); @@ -38,8 +37,7 @@ public class TermController { } @GetMapping - public List getTerms( - @RequestParam("domainId") Long domainId, + public List getTerms(@RequestParam("domainId") Long domainId, @RequestParam(name = "queryKey", required = false) String queryKey) { return termService.getTerms(domainId, queryKey); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/DownloadService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/DownloadService.java index 3bd0f404a..a9714fb5a 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/DownloadService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/DownloadService.java @@ -8,9 +8,8 @@ import com.tencent.supersonic.headless.api.pojo.request.DownloadMetricReq; public interface DownloadService { - void downloadByStruct( - DownloadMetricReq downloadStructReq, User user, HttpServletResponse response) - throws Exception; + void downloadByStruct(DownloadMetricReq downloadStructReq, User user, + HttpServletResponse response) throws Exception; void batchDownload(BatchDownloadReq batchDownloadReq, User user, HttpServletResponse response) throws Exception; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/SchemaService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/SchemaService.java index e6c0222a0..a68c430e8 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/SchemaService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/SchemaService.java @@ -57,11 +57,9 @@ public interface SchemaService { List getDomainDataSetTree(); - void getSchemaYamlTpl( - SemanticSchemaResp semanticSchemaResp, + void getSchemaYamlTpl(SemanticSchemaResp semanticSchemaResp, Map> dimensionYamlMap, - List dataModelYamlTplList, - List metricYamlTplList, + List dataModelYamlTplList, List metricYamlTplList, Map modelIdName); ItemDateResp getItemDate(ItemDateFilter dimension, ItemDateFilter metric); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/AppServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/AppServiceImpl.java index e7d9e9451..4c73e0961 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/AppServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/AppServiceImpl.java @@ -44,8 +44,8 @@ public class AppServiceImpl extends ServiceImpl implements App private DimensionService dimensionService; - public AppServiceImpl( - AppMapper appMapper, MetricService metricService, DimensionService dimensionService) { + public AppServiceImpl(AppMapper appMapper, MetricService metricService, + DimensionService dimensionService) { this.appMapper = appMapper; this.metricService = metricService; this.dimensionService = dimensionService; @@ -110,16 +110,13 @@ public class AppServiceImpl extends ServiceImpl implements App PageHelper.startPage(appQueryReq.getCurrent(), appQueryReq.getPageSize()) .doSelectPageInfo(() -> queryApp(appQueryReq)); PageInfo appPageInfo = PageUtils.pageInfo2PageInfoVo(appDOPageInfo); - Map metricResps = - metricService.getMetrics(new MetaFilter()).stream() - .collect(Collectors.toMap(MetricResp::getId, m -> m)); - Map dimensionResps = - dimensionService.getDimensions(new MetaFilter()).stream() - .collect(Collectors.toMap(DimensionResp::getId, m -> m)); - appPageInfo.setList( - appDOPageInfo.getList().stream() - .map(appDO -> convert(appDO, dimensionResps, metricResps, user)) - .collect(Collectors.toList())); + Map metricResps = metricService.getMetrics(new MetaFilter()).stream() + .collect(Collectors.toMap(MetricResp::getId, m -> m)); + Map dimensionResps = dimensionService.getDimensions(new MetaFilter()) + .stream().collect(Collectors.toMap(DimensionResp::getId, m -> m)); + appPageInfo.setList(appDOPageInfo.getList().stream() + .map(appDO -> convert(appDO, dimensionResps, metricResps, user)) + .collect(Collectors.toList())); return appPageInfo; } @@ -141,12 +138,10 @@ public class AppServiceImpl extends ServiceImpl implements App @Override public AppDetailResp getApp(Integer id, User user) { AppDO appDO = getAppDO(id); - Map metricResps = - metricService.getMetrics(new MetaFilter()).stream() - .collect(Collectors.toMap(MetricResp::getId, m -> m)); - Map dimensionResps = - dimensionService.getDimensions(new MetaFilter()).stream() - .collect(Collectors.toMap(DimensionResp::getId, m -> m)); + Map metricResps = metricService.getMetrics(new MetaFilter()).stream() + .collect(Collectors.toMap(MetricResp::getId, m -> m)); + Map dimensionResps = dimensionService.getDimensions(new MetaFilter()) + .stream().collect(Collectors.toMap(DimensionResp::getId, m -> m)); checkAuth(appDO, user); return convertDetail(appDO, dimensionResps, metricResps); } @@ -179,11 +174,8 @@ public class AppServiceImpl extends ServiceImpl implements App && appDO.getOwner().contains(user.getName()); } - private AppResp convert( - AppDO appDO, - Map dimensionMap, - Map metricMap, - User user) { + private AppResp convert(AppDO appDO, Map dimensionMap, + Map metricMap, User user) { AppResp app = new AppResp(); BeanMapper.mapper(appDO, app); AppConfig appConfig = JSONObject.parseObject(appDO.getConfig(), AppConfig.class); @@ -198,8 +190,8 @@ public class AppServiceImpl extends ServiceImpl implements App return convertDetail(appDO, new HashMap<>(), new HashMap<>()); } - private AppDetailResp convertDetail( - AppDO appDO, Map dimensionMap, Map metricMap) { + private AppDetailResp convertDetail(AppDO appDO, Map dimensionMap, + Map metricMap) { AppDetailResp app = new AppDetailResp(); BeanMapper.mapper(appDO, app); AppConfig appConfig = JSONObject.parseObject(appDO.getConfig(), AppConfig.class); @@ -209,50 +201,24 @@ public class AppServiceImpl extends ServiceImpl implements App return app; } - private void fillItemName( - AppConfig appConfig, - Map dimensionMap, + private void fillItemName(AppConfig appConfig, Map dimensionMap, Map metricMap) { - appConfig - .getItems() - .forEach( - metricItem -> { - metricItem.setName( - metricMap - .getOrDefault(metricItem.getId(), new MetricResp()) - .getName()); - metricItem.setBizName( - metricMap - .getOrDefault(metricItem.getId(), new MetricResp()) - .getBizName()); - metricItem.setCreatedBy( - metricMap - .getOrDefault(metricItem.getId(), new MetricResp()) - .getCreatedBy()); - metricItem - .getRelateItems() - .forEach( - dimensionItem -> { - dimensionItem.setName( - dimensionMap - .getOrDefault( - dimensionItem.getId(), - new DimensionResp()) - .getName()); - dimensionItem.setBizName( - dimensionMap - .getOrDefault( - dimensionItem.getId(), - new DimensionResp()) - .getBizName()); - dimensionItem.setCreatedBy( - dimensionMap - .getOrDefault( - dimensionItem.getId(), - new DimensionResp()) - .getCreatedBy()); - }); - }); + appConfig.getItems().forEach(metricItem -> { + metricItem.setName( + metricMap.getOrDefault(metricItem.getId(), new MetricResp()).getName()); + metricItem.setBizName( + metricMap.getOrDefault(metricItem.getId(), new MetricResp()).getBizName()); + metricItem.setCreatedBy( + metricMap.getOrDefault(metricItem.getId(), new MetricResp()).getCreatedBy()); + metricItem.getRelateItems().forEach(dimensionItem -> { + dimensionItem.setName(dimensionMap + .getOrDefault(dimensionItem.getId(), new DimensionResp()).getName()); + dimensionItem.setBizName(dimensionMap + .getOrDefault(dimensionItem.getId(), new DimensionResp()).getBizName()); + dimensionItem.setCreatedBy(dimensionMap + .getOrDefault(dimensionItem.getId(), new DimensionResp()).getCreatedBy()); + }); + }); } private String getUniqueId() { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/CanvasServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/CanvasServiceImpl.java index 77010bbe3..4e8508e2b 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/CanvasServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/CanvasServiceImpl.java @@ -27,11 +27,14 @@ import java.util.List; public class CanvasServiceImpl extends ServiceImpl implements CanvasService { - @Autowired private ModelService modelService; + @Autowired + private ModelService modelService; - @Autowired private DimensionService dimensionService; + @Autowired + private DimensionService dimensionService; - @Autowired private MetricService metricService; + @Autowired + private MetricService metricService; @Override public List getCanvasList(Long domainId) { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/CollectServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/CollectServiceImpl.java index 7e643603e..54b5fecd9 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/CollectServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/CollectServiceImpl.java @@ -19,7 +19,8 @@ import java.util.List; public class CollectServiceImpl implements CollectService { public static final String type = "metric"; - @Resource private CollectMapper collectMapper; + @Resource + private CollectMapper collectMapper; @Override public Boolean collect(User user, CollectDO collectReq) { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java index 8c20bdd58..3014bc788 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DataSetServiceImpl.java @@ -58,13 +58,20 @@ import java.util.stream.Collectors; public class DataSetServiceImpl extends ServiceImpl implements DataSetService { - @Autowired private DomainService domainService; + @Autowired + private DomainService domainService; - @Lazy @Autowired private DimensionService dimensionService; + @Lazy + @Autowired + private DimensionService dimensionService; - @Lazy @Autowired private MetricService metricService; + @Lazy + @Autowired + private MetricService metricService; - @Lazy @Autowired private TagMetaService tagMetaService; + @Lazy + @Autowired + private TagMetaService tagMetaService; @Override public DataSetResp save(DataSetReq dataSetReq, User user) { @@ -152,24 +159,21 @@ public class DataSetServiceImpl extends ServiceImpl List dataSetFilterByAuth = getDataSetFilterByAuth(dataSetResps, user); dataSetRespSet.addAll(dataSetFilterByAuth); if (domainId != null && domainId > 0) { - dataSetRespSet = - dataSetRespSet.stream() - .filter(modelResp -> modelResp.getDomainId().equals(domainId)) - .collect(Collectors.toSet()); + dataSetRespSet = dataSetRespSet.stream() + .filter(modelResp -> modelResp.getDomainId().equals(domainId)) + .collect(Collectors.toSet()); } - return dataSetRespSet.stream() - .sorted(Comparator.comparingLong(DataSetResp::getId)) + return dataSetRespSet.stream().sorted(Comparator.comparingLong(DataSetResp::getId)) .collect(Collectors.toList()); } private List getDataSetFilterByAuth(List dataSetResps, User user) { - return dataSetResps.stream() - .filter(dataSetResp -> checkAdminPermission(user, dataSetResp)) + return dataSetResps.stream().filter(dataSetResp -> checkAdminPermission(user, dataSetResp)) .collect(Collectors.toList()); } - private List getDataSetFilterByDomainAuth( - List dataSetResps, User user) { + private List getDataSetFilterByDomainAuth(List dataSetResps, + User user) { Set domainResps = domainService.getDomainAuthSet(user, AuthType.ADMIN); if (CollectionUtils.isEmpty(domainResps)) { return Lists.newArrayList(); @@ -190,14 +194,10 @@ public class DataSetServiceImpl extends ServiceImpl dataSetResp.setQueryConfig( JSONObject.parseObject(dataSetDO.getQueryConfig(), QueryConfig.class)); } - dataSetResp.setAdmins( - StringUtils.isBlank(dataSetDO.getAdmin()) - ? Lists.newArrayList() - : Arrays.asList(dataSetDO.getAdmin().split(","))); - dataSetResp.setAdminOrgs( - StringUtils.isBlank(dataSetDO.getAdminOrg()) - ? Lists.newArrayList() - : Arrays.asList(dataSetDO.getAdminOrg().split(","))); + dataSetResp.setAdmins(StringUtils.isBlank(dataSetDO.getAdmin()) ? Lists.newArrayList() + : Arrays.asList(dataSetDO.getAdmin().split(","))); + dataSetResp.setAdminOrgs(StringUtils.isBlank(dataSetDO.getAdminOrg()) ? Lists.newArrayList() + : Arrays.asList(dataSetDO.getAdminOrg().split(","))); dataSetResp.setTypeEnum(TypeEnums.DATASET); List dimensionItems = tagMetaService.getTagItems(dataSetResp.dimensionIds(), TagDefineType.DIMENSION); @@ -246,14 +246,10 @@ public class DataSetServiceImpl extends ServiceImpl metaFilter.setIds(dataSetIds); List dataSetList = getDataSetList(metaFilter); return dataSetList.stream() - .flatMap( - dataSetResp -> - dataSetResp.getAllModels().stream() - .map(modelId -> Pair.of(modelId, dataSetResp.getId()))) - .collect( - Collectors.groupingBy( - Pair::getLeft, - Collectors.mapping(Pair::getRight, Collectors.toList()))); + .flatMap(dataSetResp -> dataSetResp.getAllModels().stream() + .map(modelId -> Pair.of(modelId, dataSetResp.getId()))) + .collect(Collectors.groupingBy(Pair::getLeft, + Collectors.mapping(Pair::getRight, Collectors.toList()))); } @Override @@ -285,14 +281,9 @@ public class DataSetServiceImpl extends ServiceImpl } private List findDuplicates(List list, Function keyExtractor) { - return list.stream() - .collect(Collectors.groupingBy(keyExtractor, Collectors.counting())) - .entrySet() - .stream() - .filter(entry -> entry.getValue() > 1) - .map(Map.Entry::getKey) - .map(Object::toString) - .collect(Collectors.toList()); + return list.stream().collect(Collectors.groupingBy(keyExtractor, Collectors.counting())) + .entrySet().stream().filter(entry -> entry.getValue() > 1).map(Map.Entry::getKey) + .map(Object::toString).collect(Collectors.toList()); } public Long getDataSetIdFromSql(String sql, User user) { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java index 60959ebfc..03654b9f5 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DatabaseServiceImpl.java @@ -45,9 +45,12 @@ import java.util.stream.Collectors; public class DatabaseServiceImpl extends ServiceImpl implements DatabaseService { - @Autowired private SqlUtils sqlUtils; + @Autowired + private SqlUtils sqlUtils; - @Lazy @Autowired private ModelService datasourceService; + @Lazy + @Autowired + private ModelService datasourceService; @Override public boolean testConnect(DatabaseReq databaseReq, User user) { @@ -84,19 +87,18 @@ public class DatabaseServiceImpl extends ServiceImpl databaseResps, User user) { - databaseResps.forEach( - databaseResp -> { - if (databaseResp.getAdmins().contains(user.getName()) - || user.getName().equalsIgnoreCase(databaseResp.getCreatedBy()) - || user.isSuperAdmin()) { - databaseResp.setHasPermission(true); - databaseResp.setHasEditPermission(true); - databaseResp.setHasUsePermission(true); - } - if (databaseResp.getViewers().contains(user.getName())) { - databaseResp.setHasUsePermission(true); - } - }); + databaseResps.forEach(databaseResp -> { + if (databaseResp.getAdmins().contains(user.getName()) + || user.getName().equalsIgnoreCase(databaseResp.getCreatedBy()) + || user.isSuperAdmin()) { + databaseResp.setHasPermission(true); + databaseResp.setHasEditPermission(true); + databaseResp.setHasUsePermission(true); + } + if (databaseResp.getViewers().contains(user.getName())) { + databaseResp.setHasUsePermission(true); + } + }); } @Override @@ -135,9 +137,8 @@ public class DatabaseServiceImpl extends ServiceImpl databaseTypeList = - databaseList.stream() - .map(databaseResp -> databaseResp.getType()) - .collect(Collectors.toList()); + List databaseTypeList = databaseList.stream() + .map(databaseResp -> databaseResp.getType()).collect(Collectors.toList()); DefaultParametersBuilder defaultParametersBuilder = new DefaultParametersBuilder(); for (String dbType : databaseTypeList) { if (!parametersBuilderMap.containsKey(dbType)) { @@ -174,8 +173,7 @@ public class DatabaseServiceImpl extends ServiceImpl admins = databaseResp.getAdmins(); List viewers = databaseResp.getViewers(); - if (!admins.contains(user.getName()) - && !viewers.contains(user.getName()) + if (!admins.contains(user.getName()) && !viewers.contains(user.getName()) && !databaseResp.getCreatedBy().equalsIgnoreCase(user.getName()) && !user.isSuperAdmin()) { - String message = - String.format( - "您暂无当前数据库%s权限, 请联系数据库创建人:%s开通", - databaseResp.getName(), databaseResp.getCreatedBy()); + String message = String.format("您暂无当前数据库%s权限, 请联系数据库创建人:%s开通", databaseResp.getName(), + databaseResp.getCreatedBy()); throw new RuntimeException(message); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DictTaskServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DictTaskServiceImpl.java index 9ae156e75..599560b06 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DictTaskServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DictTaskServiceImpl.java @@ -46,12 +46,8 @@ public class DictTaskServiceImpl implements DictTaskService { private final FileHandler fileHandler; private final DictWordService dictWordService; - public DictTaskServiceImpl( - DictRepository dictRepository, - DictUtils dictConverter, - DictUtils dictUtils, - FileHandler fileHandler, - DictWordService dictWordService) { + public DictTaskServiceImpl(DictRepository dictRepository, DictUtils dictConverter, + DictUtils dictUtils, FileHandler fileHandler, DictWordService dictWordService) { this.dictRepository = dictRepository; this.dictConverter = dictConverter; this.dictUtils = dictUtils; @@ -80,11 +76,8 @@ public class DictTaskServiceImpl implements DictTaskService { } private DictItemResp fetchDictItemResp(DictSingleTaskReq taskReq) { - DictItemFilter dictItemFilter = - DictItemFilter.builder() - .itemId(taskReq.getItemId()) - .type(taskReq.getType()) - .build(); + DictItemFilter dictItemFilter = DictItemFilter.builder().itemId(taskReq.getItemId()) + .type(taskReq.getType()).build(); List dictItemRespList = dictRepository.queryDictConf(dictItemFilter); if (!CollectionUtils.isEmpty(dictItemRespList)) { return dictItemRespList.get(0); @@ -159,14 +152,9 @@ public class DictTaskServiceImpl implements DictTaskService { @Override public PageInfo queryDictValue(DictValueReq dictValueReq, User user) { - String fileName = - String.format( - "dic_value_%d_%s_%s", - dictValueReq.getModelId(), - dictValueReq.getType().name(), - dictValueReq.getItemId()) - + Constants.DOT - + dictFileType; + String fileName = String.format("dic_value_%d_%s_%s", dictValueReq.getModelId(), + dictValueReq.getType().name(), dictValueReq.getItemId()) + Constants.DOT + + dictFileType; PageInfo dictValueRespList = fileHandler.queryDictValue(fileName, dictValueReq); return dictValueRespList; @@ -174,14 +162,9 @@ public class DictTaskServiceImpl implements DictTaskService { @Override public String queryDictFilePath(DictValueReq dictValueReq, User user) { - String fileName = - String.format( - "dic_value_%d_%s_%s", - dictValueReq.getModelId(), - dictValueReq.getType().name(), - dictValueReq.getItemId()) - + Constants.DOT - + dictFileType; + String fileName = String.format("dic_value_%d_%s_%s", dictValueReq.getModelId(), + dictValueReq.getType().name(), dictValueReq.getItemId()) + Constants.DOT + + dictFileType; return fileHandler.queryDictFilePath(fileName); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DictWordService.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DictWordService.java index 309422b3e..624a44bc9 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DictWordService.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DictWordService.java @@ -22,8 +22,10 @@ import java.util.stream.Collectors; @Slf4j public class DictWordService { - @Autowired private SchemaService schemaService; - @Autowired private KnowledgeBaseService knowledgeBaseService; + @Autowired + private SchemaService schemaService; + @Autowired + private KnowledgeBaseService knowledgeBaseService; private List preDictWords = new ArrayList<>(); @@ -37,8 +39,8 @@ public class DictWordService { long startTime = System.currentTimeMillis(); List dictWords = getAllDictWords(); List preDictWords = getPreDictWords(); - if (org.apache.commons.collections.CollectionUtils.isEqualCollection( - dictWords, preDictWords)) { + if (org.apache.commons.collections.CollectionUtils.isEqualCollection(dictWords, + preDictWords)) { log.debug("Dictionary hasn't been reloaded."); return; } @@ -61,8 +63,8 @@ public class DictWordService { return words; } - private void addWordsByType( - DictWordType value, List metas, List natures) { + private void addWordsByType(DictWordType value, List metas, + List natures) { metas = distinct(metas); List natureList = WordBuilderFactory.get(value).getDictWords(metas); log.debug("nature type:{} , nature size:{}", value.name(), natureList.size()); @@ -87,8 +89,6 @@ public class DictWordService { return metas.stream() .collect( Collectors.toMap(SchemaElement::getId, Function.identity(), (e1, e2) -> e1)) - .values() - .stream() - .collect(Collectors.toList()); + .values().stream().collect(Collectors.toList()); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DimensionServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DimensionServiceImpl.java index f5d32ccce..6b8d231a8 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DimensionServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DimensionServiceImpl.java @@ -74,15 +74,12 @@ public class DimensionServiceImpl extends ServiceImpl dimensionResps = getDimensions(modelId); - Map bizNameMap = - dimensionResps.stream() - .collect( - Collectors.toMap( - DimensionResp::getBizName, a -> a, (k1, k2) -> k1)); - Map nameMap = - dimensionResps.stream() - .collect(Collectors.toMap(DimensionResp::getName, a -> a, (k1, k2) -> k1)); + Map bizNameMap = dimensionResps.stream() + .collect(Collectors.toMap(DimensionResp::getBizName, a -> a, (k1, k2) -> k1)); + Map nameMap = dimensionResps.stream() + .collect(Collectors.toMap(DimensionResp::getName, a -> a, (k1, k2) -> k1)); List dimensionToInsert = Lists.newArrayList(); - dimensionReqs.stream() - .forEach( - dimension -> { - if (!bizNameMap.containsKey(dimension.getBizName()) - && !nameMap.containsKey(dimension.getName())) { - dimensionToInsert.add(dimension); - } else { - DimensionResp dimensionRespByBizName = - bizNameMap.get(dimension.getBizName()); - DimensionResp dimensionRespByName = - nameMap.get(dimension.getName()); - if (null != dimensionRespByBizName - && isChange(dimension, dimensionRespByBizName)) { - dimension.setId(dimensionRespByBizName.getId()); - this.updateDimension(dimension, user); - } else { - if (null != dimensionRespByName - && isChange(dimension, dimensionRespByName)) { - dimension.setId(dimensionRespByName.getId()); - this.updateDimension(dimension, user); - } - } - } - }); + dimensionReqs.stream().forEach(dimension -> { + if (!bizNameMap.containsKey(dimension.getBizName()) + && !nameMap.containsKey(dimension.getName())) { + dimensionToInsert.add(dimension); + } else { + DimensionResp dimensionRespByBizName = bizNameMap.get(dimension.getBizName()); + DimensionResp dimensionRespByName = nameMap.get(dimension.getName()); + if (null != dimensionRespByBizName && isChange(dimension, dimensionRespByBizName)) { + dimension.setId(dimensionRespByBizName.getId()); + this.updateDimension(dimension, user); + } else { + if (null != dimensionRespByName && isChange(dimension, dimensionRespByName)) { + dimension.setId(dimensionRespByName.getId()); + this.updateDimension(dimension, user); + } + } + } + }); if (CollectionUtils.isEmpty(dimensionToInsert)) { return; } List dimensionDOS = - dimensionToInsert.stream() - .peek(dimension -> dimension.createdBy(user.getName())) - .map(DimensionConverter::convert2DimensionDO) - .collect(Collectors.toList()); + dimensionToInsert.stream().peek(dimension -> dimension.createdBy(user.getName())) + .map(DimensionConverter::convert2DimensionDO).collect(Collectors.toList()); dimensionRepository.createDimensionBatch(dimensionDOS); sendEventBatch(dimensionDOS, EventType.ADD); } @@ -166,13 +151,9 @@ public class DimensionServiceImpl extends ServiceImpl { - dimensionDO.setStatus(metaBatchReq.getStatus()); - dimensionDO.setUpdatedAt(new Date()); - dimensionDO.setUpdatedBy(user.getName()); - }) - .collect(Collectors.toList()); + dimensionDOS = dimensionDOS.stream().peek(dimensionDO -> { + dimensionDO.setStatus(metaBatchReq.getStatus()); + dimensionDO.setUpdatedAt(new Date()); + dimensionDO.setUpdatedBy(user.getName()); + }).collect(Collectors.toList()); dimensionRepository.batchUpdateStatus(dimensionDOS); if (StatusEnum.OFFLINE.getCode().equals(metaBatchReq.getStatus()) || StatusEnum.DELETED.getCode().equals(metaBatchReq.getStatus())) { @@ -303,8 +280,8 @@ public class DimensionServiceImpl extends ServiceImpl filterByField( - List dimensionResps, List fields) { + private List filterByField(List dimensionResps, + List fields) { List dimensionFiltered = Lists.newArrayList(); for (DimensionResp dimensionResp : dimensionResps) { for (String field : fields) { @@ -338,26 +315,17 @@ public class DimensionServiceImpl extends ServiceImpl modelMap = modelService.getModelMap(modelFilter); List dimensionResps = Lists.newArrayList(); if (!CollectionUtils.isEmpty(dimensionDOS)) { - dimensionResps = - dimensionDOS.stream() - .map( - dimensionDO -> - DimensionConverter.convert2DimensionResp( - dimensionDO, modelMap)) - .collect(Collectors.toList()); + dimensionResps = dimensionDOS.stream().map( + dimensionDO -> DimensionConverter.convert2DimensionResp(dimensionDO, modelMap)) + .collect(Collectors.toList()); } return dimensionResps; } @Override public List mockAlias(DimensionReq dimensionReq, String mockType, User user) { - String mockAlias = - aliasGenerateHelper.generateAlias( - mockType, - dimensionReq.getName(), - dimensionReq.getBizName(), - "", - dimensionReq.getDescription()); + String mockAlias = aliasGenerateHelper.generateAlias(mockType, dimensionReq.getName(), + dimensionReq.getBizName(), "", dimensionReq.getDescription()); String ret = aliasGenerateHelper.extractJsonStringFromAiMessage(mockAlias); return JSONObject.parseObject(ret, new TypeReference>() {}); } @@ -373,13 +341,8 @@ public class DimensionServiceImpl extends ServiceImpl> resultList = semanticQueryResp.getResultList(); List valueList = new ArrayList<>(); @@ -403,11 +366,9 @@ public class DimensionServiceImpl extends ServiceImpl dimensionReqs) { Long modelId = dimensionReqs.get(0).getModelId(); List dimensionResps = getDimensions(modelId); - Map bizNameMap = - dimensionResps.stream() - .collect( - Collectors.toMap( - DimensionResp::getBizName, a -> a, (k1, k2) -> k1)); - Map nameMap = - dimensionResps.stream() - .collect(Collectors.toMap(DimensionResp::getName, a -> a, (k1, k2) -> k1)); + Map bizNameMap = dimensionResps.stream() + .collect(Collectors.toMap(DimensionResp::getBizName, a -> a, (k1, k2) -> k1)); + Map nameMap = dimensionResps.stream() + .collect(Collectors.toMap(DimensionResp::getName, a -> a, (k1, k2) -> k1)); for (DimensionReq dimensionReq : dimensionReqs) { String forbiddenCharacters = NameCheckUtils.findForbiddenCharacters(dimensionReq.getName()); if (StringUtils.isNotBlank(forbiddenCharacters)) { - throw new InvalidArgumentException( - String.format( - "名称包含特殊字符, 请修改: %s,特殊字符: %s", - dimensionReq.getName(), forbiddenCharacters)); + throw new InvalidArgumentException(String.format("名称包含特殊字符, 请修改: %s,特殊字符: %s", + dimensionReq.getName(), forbiddenCharacters)); } if (bizNameMap.containsKey(dimensionReq.getBizName())) { DimensionResp dimensionResp = bizNameMap.get(dimensionReq.getBizName()); if (!dimensionResp.getId().equals(dimensionReq.getId())) { - throw new RuntimeException( - String.format( - "该主题域下存在相同的维度字段名:%s 创建人:%s", - dimensionReq.getBizName(), dimensionResp.getCreatedBy())); + throw new RuntimeException(String.format("该主题域下存在相同的维度字段名:%s 创建人:%s", + dimensionReq.getBizName(), dimensionResp.getCreatedBy())); } } if (nameMap.containsKey(dimensionReq.getName())) { DimensionResp dimensionResp = nameMap.get(dimensionReq.getName()); if (!dimensionResp.getId().equals(dimensionReq.getId())) { - throw new RuntimeException( - String.format( - "该主题域下存在相同的维度名:%s 创建人:%s", - dimensionReq.getName(), dimensionResp.getCreatedBy())); + throw new RuntimeException(String.format("该主题域下存在相同的维度名:%s 创建人:%s", + dimensionReq.getName(), dimensionResp.getCreatedBy())); } } } @@ -478,19 +429,12 @@ public class DimensionServiceImpl extends ServiceImpl dimensionDOS, EventType eventType) { - List dataItems = - dimensionDOS.stream() - .map( - dimensionDO -> - DataItem.builder() - .id(dimensionDO.getId() + Constants.UNDERLINE) - .name(dimensionDO.getName()) - .modelId( - dimensionDO.getModelId() - + Constants.UNDERLINE) - .type(TypeEnums.DIMENSION) - .build()) - .collect(Collectors.toList()); + List dataItems = dimensionDOS.stream() + .map(dimensionDO -> DataItem.builder().id(dimensionDO.getId() + Constants.UNDERLINE) + .name(dimensionDO.getName()) + .modelId(dimensionDO.getModelId() + Constants.UNDERLINE) + .type(TypeEnums.DIMENSION).build()) + .collect(Collectors.toList()); return new DataEvent(this, dataItems, eventType); } @@ -499,10 +443,8 @@ public class DimensionServiceImpl extends ServiceImpl getDomainList(List domainIds) { - return getDomainList().stream() - .filter(domainDO -> domainIds.contains(domainDO.getId())) + return getDomainList().stream().filter(domainDO -> domainIds.contains(domainDO.getId())) .collect(Collectors.toList()); } @@ -113,10 +110,8 @@ public class DomainServiceImpl implements DomainService { domainResp.setHasModel(true); } } - return new ArrayList<>(domainWithAuthAll) - .stream() - .sorted(Comparator.comparingLong(DomainResp::getId)) - .collect(Collectors.toList()); + return new ArrayList<>(domainWithAuthAll).stream() + .sorted(Comparator.comparingLong(DomainResp::getId)).collect(Collectors.toList()); } @Override @@ -125,19 +120,16 @@ public class DomainServiceImpl implements DomainService { Set orgIds = userService.getUserAllOrgId(user.getName()); Set domainWithAuth = Sets.newHashSet(); if (authTypeEnum.equals(AuthType.ADMIN)) { - domainWithAuth = - domainResps.stream() - .filter(domainResp -> checkAdminPermission(orgIds, user, domainResp)) - .collect(Collectors.toSet()); - return domainWithAuth.stream() - .peek(domainResp -> domainResp.setHasEditPermission(true)) + domainWithAuth = domainResps.stream() + .filter(domainResp -> checkAdminPermission(orgIds, user, domainResp)) + .collect(Collectors.toSet()); + return domainWithAuth.stream().peek(domainResp -> domainResp.setHasEditPermission(true)) .collect(Collectors.toSet()); } if (authTypeEnum.equals(AuthType.VISIBLE)) { - domainWithAuth = - domainResps.stream() - .filter(domainResp -> checkViewPermission(orgIds, user, domainResp)) - .collect(Collectors.toSet()); + domainWithAuth = domainResps.stream() + .filter(domainResp -> checkViewPermission(orgIds, user, domainResp)) + .collect(Collectors.toSet()); } return domainWithAuth; @@ -217,9 +209,8 @@ public class DomainServiceImpl implements DomainService { public Map getDomainFullPathMap() { Map domainFullPathMap = new HashMap<>(); List domainDOList = domainRepository.getDomainList(); - Map domainDOMap = - domainDOList.stream() - .collect(Collectors.toMap(DomainDO::getId, a -> a, (k1, k2) -> k1)); + Map domainDOMap = domainDOList.stream() + .collect(Collectors.toMap(DomainDO::getId, a -> a, (k1, k2) -> k1)); for (DomainDO domainDO : domainDOList) { final Long domainId = domainDO.getId(); StringBuilder fullPath = new StringBuilder(domainDO.getBizName() + "/"); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DownloadServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DownloadServiceImpl.java index 1f8893025..084c9c1f5 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DownloadServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/DownloadServiceImpl.java @@ -69,9 +69,7 @@ public class DownloadServiceImpl implements DownloadService { private SemanticLayerService queryService; - public DownloadServiceImpl( - MetricService metricService, - DimensionService dimensionService, + public DownloadServiceImpl(MetricService metricService, DimensionService dimensionService, SemanticLayerService queryService) { this.metricService = metricService; this.dimensionService = dimensionService; @@ -79,9 +77,8 @@ public class DownloadServiceImpl implements DownloadService { } @Override - public void downloadByStruct( - DownloadMetricReq downloadMetricReq, User user, HttpServletResponse response) - throws Exception { + public void downloadByStruct(DownloadMetricReq downloadMetricReq, User user, + HttpServletResponse response) throws Exception { String fileName = String.format("%s_%s.xlsx", "supersonic", DateUtils.format(new Date(), dateFormat)); File file = FileUtils.createTmpFile(fileName); @@ -91,14 +88,10 @@ public class DownloadServiceImpl implements DownloadService { queryService.queryByReq(queryStructReq.convert(true), user); DataDownload dataDownload = buildDataDownload(queryResult, queryStructReq, downloadMetricReq.isTransform()); - EasyExcel.write(file) - .sheet("Sheet1") - .head(dataDownload.getHeaders()) + EasyExcel.write(file).sheet("Sheet1").head(dataDownload.getHeaders()) .doWrite(dataDownload.getData()); } catch (RuntimeException e) { - EasyExcel.write(file) - .sheet("Sheet1") - .head(buildErrMessageHead()) + EasyExcel.write(file).sheet("Sheet1").head(buildErrMessageHead()) .doWrite(buildErrMessageData(e.getMessage())); return; } @@ -106,9 +99,8 @@ public class DownloadServiceImpl implements DownloadService { } @Override - public void batchDownload( - BatchDownloadReq batchDownloadReq, User user, HttpServletResponse response) - throws Exception { + public void batchDownload(BatchDownloadReq batchDownloadReq, User user, + HttpServletResponse response) throws Exception { String fileName = String.format("%s_%s.xlsx", "supersonic", DateUtils.format(new Date(), dateFormat)); File file = FileUtils.createTmpFile(fileName); @@ -127,16 +119,13 @@ public class DownloadServiceImpl implements DownloadService { metaFilter.setIds(metricIds); List metricResps = metricService.getMetrics(metaFilter); Map> metricMap = getMetricMap(metricResps); - List dimensionIds = - metricResps.stream() - .map(metricResp -> metricService.getDrillDownDimension(metricResp.getId())) - .flatMap(Collection::stream) - .map(DrillDownDimension::getDimensionId) - .collect(Collectors.toList()); + List dimensionIds = metricResps.stream() + .map(metricResp -> metricService.getDrillDownDimension(metricResp.getId())) + .flatMap(Collection::stream).map(DrillDownDimension::getDimensionId) + .collect(Collectors.toList()); metaFilter.setIds(dimensionIds); - Map dimensionRespMap = - dimensionService.getDimensions(metaFilter).stream() - .collect(Collectors.toMap(DimensionResp::getId, d -> d)); + Map dimensionRespMap = dimensionService.getDimensions(metaFilter) + .stream().collect(Collectors.toMap(DimensionResp::getId, d -> d)); ExcelWriter excelWriter = EasyExcel.write(file).build(); int sheetCount = 1; for (List metrics : metricMap.values()) { @@ -152,18 +141,13 @@ public class DownloadServiceImpl implements DownloadService { QuerySqlReq querySqlReq = queryStructReq.convert(); querySqlReq.setNeedAuth(true); SemanticQueryResp queryResult = queryService.queryByReq(querySqlReq, user); - DataDownload dataDownload = - buildDataDownload( - queryResult, queryStructReq, batchDownloadReq.isTransform()); - WriteSheet writeSheet = - EasyExcel.writerSheet("Sheet" + sheetCount) - .head(dataDownload.getHeaders()) - .build(); + DataDownload dataDownload = buildDataDownload(queryResult, queryStructReq, + batchDownloadReq.isTransform()); + WriteSheet writeSheet = EasyExcel.writerSheet("Sheet" + sheetCount) + .head(dataDownload.getHeaders()).build(); excelWriter.write(dataDownload.getData(), writeSheet); } catch (RuntimeException e) { - EasyExcel.write(file) - .sheet("Sheet1") - .head(buildErrMessageHead()) + EasyExcel.write(file).sheet("Sheet1").head(buildErrMessageHead()) .doWrite(buildErrMessageData(e.getMessage())); return; } @@ -220,11 +204,8 @@ public class DownloadServiceImpl implements DownloadService { return data; } - private List> buildData( - List> headers, - Map nameMap, - List> dataTransformed, - String metricName) { + private List> buildData(List> headers, Map nameMap, + List> dataTransformed, String metricName) { List> data = Lists.newArrayList(); for (Map map : dataTransformed) { List row = Lists.newArrayList(); @@ -246,27 +227,20 @@ public class DownloadServiceImpl implements DownloadService { return data; } - private DataDownload buildDataDownload( - SemanticQueryResp queryResult, QueryStructReq queryStructReq, boolean isTransform) { + private DataDownload buildDataDownload(SemanticQueryResp queryResult, + QueryStructReq queryStructReq, boolean isTransform) { List metricColumns = queryResult.getMetricColumns(); List dimensionColumns = queryResult.getDimensionColumns(); if (isTransform && !CollectionUtils.isEmpty(metricColumns)) { QueryColumn metric = metricColumns.get(0); List groups = queryStructReq.getGroups(); List> dataTransformed = - DataTransformUtils.transform( - queryResult.getResultList(), - metric.getNameEn(), - groups, - queryStructReq.getDateInfo()); + DataTransformUtils.transform(queryResult.getResultList(), metric.getNameEn(), + groups, queryStructReq.getDateInfo()); List> headers = buildHeader(dimensionColumns, queryStructReq.getDateInfo().getDateList()); - List> data = - buildData( - headers, - getDimensionNameMap(dimensionColumns), - dataTransformed, - metric.getName()); + List> data = buildData(headers, getDimensionNameMap(dimensionColumns), + dataTransformed, metric.getName()); return DataDownload.builder().headers(headers).data(data).build(); } else { List> data = buildData(queryResult); @@ -275,19 +249,15 @@ public class DownloadServiceImpl implements DownloadService { } } - private QueryStructReq buildDownloadReq( - List dimensionResps, - MetricResp metricResp, - BatchDownloadReq batchDownloadReq) { + private QueryStructReq buildDownloadReq(List dimensionResps, + MetricResp metricResp, BatchDownloadReq batchDownloadReq) { DateConf dateConf = batchDownloadReq.getDateInfo(); Set modelIds = dimensionResps.stream().map(DimensionResp::getModelId).collect(Collectors.toSet()); modelIds.add(metricResp.getModelId()); QueryStructReq queryStructReq = new QueryStructReq(); - queryStructReq.setGroups( - dimensionResps.stream() - .map(DimensionResp::getBizName) - .collect(Collectors.toList())); + queryStructReq.setGroups(dimensionResps.stream().map(DimensionResp::getBizName) + .collect(Collectors.toList())); queryStructReq.getGroups().add(0, getTimeDimension(dateConf)); Aggregator aggregator = new Aggregator(); aggregator.setColumn(metricResp.getBizName()); @@ -325,19 +295,15 @@ public class DownloadServiceImpl implements DownloadService { .collect(Collectors.toMap(QueryColumn::getName, QueryColumn::getNameEn)); } - private List getMetricRelaDimensions( - MetricResp metricResp, Map dimensionRespMap) { - if (metricResp.getRelateDimension() == null - || CollectionUtils.isEmpty( - metricResp.getRelateDimension().getDrillDownDimensions())) { + private List getMetricRelaDimensions(MetricResp metricResp, + Map dimensionRespMap) { + if (metricResp.getRelateDimension() == null || CollectionUtils + .isEmpty(metricResp.getRelateDimension().getDrillDownDimensions())) { return Lists.newArrayList(); } - return metricResp.getRelateDimension().getDrillDownDimensions().stream() - .map( - drillDownDimension -> - dimensionRespMap.get(drillDownDimension.getDimensionId())) - .filter(Objects::nonNull) - .collect(Collectors.toList()); + return metricResp.getRelateDimension().getDrillDownDimensions().stream().map( + drillDownDimension -> dimensionRespMap.get(drillDownDimension.getDimensionId())) + .filter(Objects::nonNull).collect(Collectors.toList()); } private void downloadFile(HttpServletResponse response, File file, String filename) { @@ -345,8 +311,7 @@ public class DownloadServiceImpl implements DownloadService { byte[] buffer = readFileToByteArray(file); response.reset(); response.setCharacterEncoding("UTF-8"); - response.addHeader( - "Content-Disposition", + response.addHeader("Content-Disposition", "attachment;filename=" + URLEncoder.encode(filename, "UTF-8")); response.addHeader("Content-Length", "" + file.length()); try (OutputStream outputStream = new BufferedOutputStream(response.getOutputStream())) { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/FlightServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/FlightServiceImpl.java index a4421c60f..21d0fe910 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/FlightServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/FlightServiceImpl.java @@ -88,10 +88,8 @@ public class FlightServiceImpl extends BasicFlightSqlProducer implements FlightS private final AuthenticationConfig authenticationConfig; private final UserService userService; - public FlightServiceImpl( - SemanticLayerService queryService, - AuthenticationConfig authenticationConfig, - UserService userService) { + public FlightServiceImpl(SemanticLayerService queryService, + AuthenticationConfig authenticationConfig, UserService userService) { this.queryService = queryService; this.authenticationConfig = authenticationConfig; @@ -104,14 +102,11 @@ public class FlightServiceImpl extends BasicFlightSqlProducer implements FlightS } @Override - public void setExecutorService( - ExecutorService executorService, Integer queue, Integer expireMinute) { + public void setExecutorService(ExecutorService executorService, Integer queue, + Integer expireMinute) { this.executorService = executorService; - this.preparedStatementCache = - CacheBuilder.newBuilder() - .maximumSize(queue) - .expireAfterWrite(expireMinute, TimeUnit.MINUTES) - .build(); + this.preparedStatementCache = CacheBuilder.newBuilder().maximumSize(queue) + .expireAfterWrite(expireMinute, TimeUnit.MINUTES).build(); } @Override @@ -120,26 +115,20 @@ public class FlightServiceImpl extends BasicFlightSqlProducer implements FlightS } @Override - public void getStreamStatement( - final TicketStatementQuery ticketStatementQuery, - final CallContext context, - final ServerStreamListener listener) { + public void getStreamStatement(final TicketStatementQuery ticketStatementQuery, + final CallContext context, final ServerStreamListener listener) { final ByteString handle = ticketStatementQuery.getStatementHandle(); log.info("getStreamStatement {} ", handle); executeQuery(handle, listener); } @Override - public FlightInfo getFlightInfoStatement( - final CommandStatementQuery request, - final CallContext context, - final FlightDescriptor descriptor) { + public FlightInfo getFlightInfoStatement(final CommandStatementQuery request, + final CallContext context, final FlightDescriptor descriptor) { try { ByteString preparedStatementHandle = addPrepared(context, request.getQuery()); - TicketStatementQuery ticket = - TicketStatementQuery.newBuilder() - .setStatementHandle(preparedStatementHandle) - .build(); + TicketStatementQuery ticket = TicketStatementQuery.newBuilder() + .setStatementHandle(preparedStatementHandle).build(); return getFlightInfoForSchema(ticket, descriptor, null); } catch (Exception e) { log.error("getFlightInfoStatement error {}", e); @@ -148,10 +137,8 @@ public class FlightServiceImpl extends BasicFlightSqlProducer implements FlightS } @Override - public void getStreamPreparedStatement( - final CommandPreparedStatementQuery command, - final CallContext context, - final ServerStreamListener listener) { + public void getStreamPreparedStatement(final CommandPreparedStatementQuery command, + final CallContext context, final ServerStreamListener listener) { log.info("getStreamPreparedStatement {}", command.getPreparedStatementHandle()); executeQuery(command.getPreparedStatementHandle(), listener); } @@ -160,117 +147,88 @@ public class FlightServiceImpl extends BasicFlightSqlProducer implements FlightS SemanticQueryReq semanticQueryReq = preparedStatementCache.getIfPresent(hander); if (Objects.isNull(semanticQueryReq)) { listener.error( - CallStatus.INTERNAL - .withDescription("Failed to get prepared statement: empty") + CallStatus.INTERNAL.withDescription("Failed to get prepared statement: empty") .toRuntimeException()); log.error("getStreamPreparedStatement error {}", hander); listener.completed(); return; } - executorService.submit( - () -> { - BufferAllocator rootAllocator = new RootAllocator(); - try { - Optional authOpt = - semanticQueryReq.getParams().stream() - .filter( - p -> - p.getName() - .equals( - authenticationConfig - .getTokenHttpHeaderKey())) - .findFirst(); - if (authOpt.isPresent()) { - User user = - UserHolder.findUser( - authOpt.get().getValue(), - authenticationConfig.getTokenHttpHeaderAppKey()); - SemanticQueryResp resp = - queryService.queryByReq(semanticQueryReq, user); - ResultSet resultSet = - semanticQueryRespToResultSet( - resp, semanticQueryReq.getDataSetId()); - final Schema schema = - jdbcToArrowSchema(resultSet.getMetaData(), defaultCalendar); - try (final VectorSchemaRoot vectorSchemaRoot = - VectorSchemaRoot.create(schema, rootAllocator)) { - final VectorLoader loader = new VectorLoader(vectorSchemaRoot); - listener.start(vectorSchemaRoot); - final ArrowVectorIterator iterator = - sqlToArrowVectorIterator(resultSet, rootAllocator); - while (iterator.hasNext()) { - final VectorSchemaRoot batch = iterator.next(); - if (batch.getRowCount() == 0) { - break; - } - final VectorUnloader unloader = new VectorUnloader(batch); - loader.load(unloader.getRecordBatch()); - listener.putNext(); - vectorSchemaRoot.clear(); - } - - listener.putNext(); + executorService.submit(() -> { + BufferAllocator rootAllocator = new RootAllocator(); + try { + Optional authOpt = semanticQueryReq.getParams().stream().filter( + p -> p.getName().equals(authenticationConfig.getTokenHttpHeaderKey())) + .findFirst(); + if (authOpt.isPresent()) { + User user = UserHolder.findUser(authOpt.get().getValue(), + authenticationConfig.getTokenHttpHeaderAppKey()); + SemanticQueryResp resp = queryService.queryByReq(semanticQueryReq, user); + ResultSet resultSet = + semanticQueryRespToResultSet(resp, semanticQueryReq.getDataSetId()); + final Schema schema = + jdbcToArrowSchema(resultSet.getMetaData(), defaultCalendar); + try (final VectorSchemaRoot vectorSchemaRoot = + VectorSchemaRoot.create(schema, rootAllocator)) { + final VectorLoader loader = new VectorLoader(vectorSchemaRoot); + listener.start(vectorSchemaRoot); + final ArrowVectorIterator iterator = + sqlToArrowVectorIterator(resultSet, rootAllocator); + while (iterator.hasNext()) { + final VectorSchemaRoot batch = iterator.next(); + if (batch.getRowCount() == 0) { + break; } + final VectorUnloader unloader = new VectorUnloader(batch); + loader.load(unloader.getRecordBatch()); + listener.putNext(); + vectorSchemaRoot.clear(); } - } catch (Exception e) { - listener.error( - CallStatus.INTERNAL - .withDescription( - String.format( - "Failed to get exec statement %s", - e.getMessage())) - .toRuntimeException()); - log.error("getStreamPreparedStatement error {}", hander); - } finally { - preparedStatementCache.invalidate(hander); - listener.completed(); - rootAllocator.close(); + + listener.putNext(); } - }); + } + } catch (Exception e) { + listener.error(CallStatus.INTERNAL + .withDescription( + String.format("Failed to get exec statement %s", e.getMessage())) + .toRuntimeException()); + log.error("getStreamPreparedStatement error {}", hander); + } finally { + preparedStatementCache.invalidate(hander); + listener.completed(); + rootAllocator.close(); + } + }); } @Override - public void closePreparedStatement( - final ActionClosePreparedStatementRequest request, - final CallContext context, - final StreamListener listener) { + public void closePreparedStatement(final ActionClosePreparedStatementRequest request, + final CallContext context, final StreamListener listener) { log.info("closePreparedStatement {}", request.getPreparedStatementHandle()); listener.onCompleted(); } @Override - public FlightInfo getFlightInfoPreparedStatement( - final CommandPreparedStatementQuery command, - final CallContext context, - final FlightDescriptor descriptor) { + public FlightInfo getFlightInfoPreparedStatement(final CommandPreparedStatementQuery command, + final CallContext context, final FlightDescriptor descriptor) { return getFlightInfoForSchema(command, descriptor, null); } @Override - public void createPreparedStatement( - final ActionCreatePreparedStatementRequest request, - final CallContext context, - final StreamListener listener) { + public void createPreparedStatement(final ActionCreatePreparedStatementRequest request, + final CallContext context, final StreamListener listener) { prepared(request, context, listener); } private ByteString addPrepared(final CallContext context, String query) throws Exception { - if (Arrays.asList(dataSetIdHeaderKey, nameHeaderKey, passwordHeaderKey).stream() - .anyMatch( - h -> - !context.getMiddleware(FlightConstants.HEADER_KEY) - .headers() - .containsKey(h))) { - throw new Exception( - String.format( - "Failed to create prepared statement: HeaderCallOption miss %s %s %s", - dataSetIdHeaderKey, nameHeaderKey, passwordHeaderKey)); + if (Arrays.asList(dataSetIdHeaderKey, nameHeaderKey, passwordHeaderKey).stream().anyMatch( + h -> !context.getMiddleware(FlightConstants.HEADER_KEY).headers().containsKey(h))) { + throw new Exception(String.format( + "Failed to create prepared statement: HeaderCallOption miss %s %s %s", + dataSetIdHeaderKey, nameHeaderKey, passwordHeaderKey)); } - Long dataSetId = - Long.valueOf( - context.getMiddleware(FlightConstants.HEADER_KEY) - .headers() - .get(dataSetIdHeaderKey)); + Long dataSetId = Long.valueOf(context.getMiddleware(FlightConstants.HEADER_KEY).headers() + .get(dataSetIdHeaderKey)); if (StringUtils.isBlank(query)) { throw new Exception("Failed to create prepared statement: query is empty"); } @@ -287,34 +245,28 @@ public class FlightServiceImpl extends BasicFlightSqlProducer implements FlightS querySqlReq.setParams( Arrays.asList(new Param(authenticationConfig.getTokenHttpHeaderKey(), auth))); preparedStatementCache.put(preparedStatementHandle, querySqlReq); - log.info( - "createPreparedStatement {} {} {} ", preparedStatementHandle, dataSetId, query); + log.info("createPreparedStatement {} {} {} ", preparedStatementHandle, dataSetId, + query); return preparedStatementHandle; } catch (Exception e) { throw e; } } - private void prepared( - final ActionCreatePreparedStatementRequest request, - final CallContext context, - final StreamListener listener) { + private void prepared(final ActionCreatePreparedStatementRequest request, + final CallContext context, final StreamListener listener) { try { ByteString preparedStatementHandle = addPrepared(context, request.getQuery()); - final ActionCreatePreparedStatementResult result = - ActionCreatePreparedStatementResult.newBuilder() - .setDatasetSchema(ByteString.EMPTY) - .setParameterSchema(ByteString.empty()) - .setPreparedStatementHandle(preparedStatementHandle) - .build(); + final ActionCreatePreparedStatementResult result = ActionCreatePreparedStatementResult + .newBuilder().setDatasetSchema(ByteString.EMPTY) + .setParameterSchema(ByteString.empty()) + .setPreparedStatementHandle(preparedStatementHandle).build(); listener.onNext(new Result(pack(result).toByteArray())); } catch (Exception e) { listener.onError( CallStatus.INTERNAL - .withDescription( - String.format( - "Failed to create prepared statement: %s", - e.getMessage())) + .withDescription(String.format( + "Failed to create prepared statement: %s", e.getMessage())) .toRuntimeException()); } finally { listener.onCompleted(); @@ -322,13 +274,13 @@ public class FlightServiceImpl extends BasicFlightSqlProducer implements FlightS } @Override - protected List determineEndpoints( - T t, FlightDescriptor flightDescriptor, Schema schema) { + protected List determineEndpoints(T t, + FlightDescriptor flightDescriptor, Schema schema) { throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); } - private FlightInfo getFlightInfoForSchema( - final T request, final FlightDescriptor descriptor, final Schema schema) { + private FlightInfo getFlightInfoForSchema(final T request, + final FlightDescriptor descriptor, final Schema schema) { final Ticket ticket = new Ticket(pack(request).toByteArray()); Location listenLocation = Location.forGrpcInsecure(host, port); final List endpoints = @@ -359,13 +311,9 @@ public class FlightServiceImpl extends BasicFlightSqlProducer implements FlightS for (int i = 1; i <= columnNum; i++) { String columnName = resp.getColumns().get(i - 1).getNameEn(); rowSetMetaData.setColumnName(i, columnName); - Optional> valOpt = - resp.getResultList().stream() - .filter( - r -> - r.containsKey(columnName) - && Objects.nonNull(r.get(columnName))) - .findFirst(); + Optional> valOpt = resp.getResultList().stream() + .filter(r -> r.containsKey(columnName) && Objects.nonNull(r.get(columnName))) + .findFirst(); if (valOpt.isPresent()) { int type = FlightUtils.resolveType(valOpt.get()); rowSetMetaData.setColumnType(i, type); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetricServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetricServiceImpl.java index 723e7ea1b..ed021f834 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetricServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/MetricServiceImpl.java @@ -110,15 +110,10 @@ public class MetricServiceImpl extends ServiceImpl private ChatLayerService chatLayerService; - public MetricServiceImpl( - MetricRepository metricRepository, - ModelService modelService, - AliasGenerateHelper aliasGenerateHelper, - CollectService collectService, - DataSetService dataSetService, - ApplicationEventPublisher eventPublisher, - DimensionService dimensionService, - TagMetaService tagMetaService, + public MetricServiceImpl(MetricRepository metricRepository, ModelService modelService, + AliasGenerateHelper aliasGenerateHelper, CollectService collectService, + DataSetService dataSetService, ApplicationEventPublisher eventPublisher, + DimensionService dimensionService, TagMetaService tagMetaService, @Lazy ChatLayerService chatLayerService) { this.metricRepository = metricRepository; this.modelService = modelService; @@ -149,27 +144,21 @@ public class MetricServiceImpl extends ServiceImpl } Long modelId = metricReqs.get(0).getModelId(); List metricResps = getMetrics(new MetaFilter(Lists.newArrayList(modelId))); - Map bizNameMap = - metricResps.stream() - .collect(Collectors.toMap(MetricResp::getBizName, a -> a, (k1, k2) -> k1)); - Map nameMap = - metricResps.stream() - .collect(Collectors.toMap(MetricResp::getName, a -> a, (k1, k2) -> k1)); + Map bizNameMap = metricResps.stream() + .collect(Collectors.toMap(MetricResp::getBizName, a -> a, (k1, k2) -> k1)); + Map nameMap = metricResps.stream() + .collect(Collectors.toMap(MetricResp::getName, a -> a, (k1, k2) -> k1)); List metricToInsert = metricReqs.stream() - .filter( - metric -> - !bizNameMap.containsKey(metric.getBizName()) - && !nameMap.containsKey(metric.getName())) + .filter(metric -> !bizNameMap.containsKey(metric.getBizName()) + && !nameMap.containsKey(metric.getName())) .collect(Collectors.toList()); if (CollectionUtils.isEmpty(metricToInsert)) { return; } List metricDOS = - metricToInsert.stream() - .peek(metric -> metric.createdBy(user.getName())) - .map(MetricConverter::convert2MetricDO) - .collect(Collectors.toList()); + metricToInsert.stream().peek(metric -> metric.createdBy(user.getName())) + .map(MetricConverter::convert2MetricDO).collect(Collectors.toList()); metricRepository.createMetricBatch(metricDOS); sendEventBatch(metricDOS, EventType.ADD); } @@ -201,15 +190,11 @@ public class MetricServiceImpl extends ServiceImpl if (CollectionUtils.isEmpty(metricDOS)) { return; } - metricDOS = - metricDOS.stream() - .peek( - metricDO -> { - metricDO.setStatus(metaBatchReq.getStatus()); - metricDO.setUpdatedAt(new Date()); - metricDO.setUpdatedBy(user.getName()); - }) - .collect(Collectors.toList()); + metricDOS = metricDOS.stream().peek(metricDO -> { + metricDO.setStatus(metaBatchReq.getStatus()); + metricDO.setUpdatedAt(new Date()); + metricDO.setUpdatedBy(user.getName()); + }).collect(Collectors.toList()); metricRepository.batchUpdateStatus(metricDOS); if (StatusEnum.OFFLINE.getCode().equals(metaBatchReq.getStatus()) || StatusEnum.DELETED.getCode().equals(metaBatchReq.getStatus())) { @@ -314,20 +299,14 @@ public class MetricServiceImpl extends ServiceImpl return metricRespPageInfo; } Map result = - dataSetMapInfoMap.values().stream() - .map(DataSetMapInfo::getMapFields) - .filter(Objects::nonNull) - .flatMap(Collection::stream) - .filter( - schemaElementMatch -> - SchemaElementType.METRIC.equals( - schemaElementMatch.getElement().getType())) - .collect( - Collectors.toMap( - schemaElementMatch -> - schemaElementMatch.getElement().getId(), - SchemaElementMatch::getSimilarity, - (existingValue, newValue) -> existingValue)); + dataSetMapInfoMap.values().stream().map(DataSetMapInfo::getMapFields) + .filter(Objects::nonNull).flatMap(Collection::stream) + .filter(schemaElementMatch -> SchemaElementType.METRIC + .equals(schemaElementMatch.getElement().getType())) + .collect(Collectors.toMap( + schemaElementMatch -> schemaElementMatch.getElement().getId(), + SchemaElementMatch::getSimilarity, + (existingValue, newValue) -> existingValue)); List metricIds = new ArrayList<>(result.keySet()); if (CollectionUtils.isEmpty(result.keySet())) { return metricRespPageInfo; @@ -434,20 +413,13 @@ public class MetricServiceImpl extends ServiceImpl return new ArrayList<>(metricRespFiltered); } - private boolean filterByField( - List metricResps, - MetricResp metricResp, - List fields, - Set metricRespFiltered) { + private boolean filterByField(List metricResps, MetricResp metricResp, + List fields, Set metricRespFiltered) { if (MetricDefineType.METRIC.equals(metricResp.getMetricDefineType())) { - List ids = - metricResp.getMetricDefineByMetricParams().getMetrics().stream() - .map(MetricParam::getId) - .collect(Collectors.toList()); - List metricById = - metricResps.stream() - .filter(metric -> ids.contains(metric.getId())) - .collect(Collectors.toList()); + List ids = metricResp.getMetricDefineByMetricParams().getMetrics().stream() + .map(MetricParam::getId).collect(Collectors.toList()); + List metricById = metricResps.stream() + .filter(metric -> ids.contains(metric.getId())).collect(Collectors.toList()); for (MetricResp metric : metricById) { if (filterByField(metricResps, metric, fields, metricRespFiltered)) { metricRespFiltered.add(metricResp); @@ -461,12 +433,10 @@ public class MetricServiceImpl extends ServiceImpl } } else if (MetricDefineType.MEASURE.equals(metricResp.getMetricDefineType())) { List measures = metricResp.getMetricDefineByMeasureParams().getMeasures(); - List fieldNameDepended = - measures.stream() - .map(MeasureParam::getBizName) - // measure bizName = model bizName_fieldName - .map(name -> name.replaceFirst(metricResp.getModelBizName() + "_", "")) - .collect(Collectors.toList()); + List fieldNameDepended = measures.stream().map(MeasureParam::getBizName) + // measure bizName = model bizName_fieldName + .map(name -> name.replaceFirst(metricResp.getModelBizName() + "_", "")) + .collect(Collectors.toList()); if (fields.stream().anyMatch(fieldNameDepended::contains)) { metricRespFiltered.add(metricResp); return true; @@ -480,12 +450,9 @@ public class MetricServiceImpl extends ServiceImpl MetricFilter metricFilter = new MetricFilter(); metricFilter.setModelIds(Lists.newArrayList(modelId)); List metricResps = getMetrics(metricFilter); - return metricResps.stream() - .filter( - metricResp -> - MetricDefineType.FIELD.equals(metricResp.getMetricDefineType()) - || MetricDefineType.MEASURE.equals( - metricResp.getMetricDefineType())) + return metricResps.stream().filter( + metricResp -> MetricDefineType.FIELD.equals(metricResp.getMetricDefineType()) + || MetricDefineType.MEASURE.equals(metricResp.getMetricDefineType())) .collect(Collectors.toList()); } @@ -546,19 +513,10 @@ public class MetricServiceImpl extends ServiceImpl @Override public List mockAlias(MetricBaseReq metricReq, String mockType, User user) { - String mockAlias = - aliasGenerateHelper.generateAlias( - mockType, - metricReq.getName(), - metricReq.getBizName(), - "", - metricReq.getDescription()); - String ret = - mockAlias - .replaceAll("`", "") - .replace("json", "") - .replace("\n", "") - .replace(" ", ""); + String mockAlias = aliasGenerateHelper.generateAlias(mockType, metricReq.getName(), + metricReq.getBizName(), "", metricReq.getDescription()); + String ret = mockAlias.replaceAll("`", "").replace("json", "").replace("\n", "") + .replace(" ", ""); return JSONObject.parseObject(ret, new TypeReference>() {}); } @@ -568,8 +526,7 @@ public class MetricServiceImpl extends ServiceImpl if (CollectionUtils.isEmpty(metricResps)) { return new HashSet<>(); } - return metricResps.stream() - .flatMap(metricResp -> metricResp.getClassifications().stream()) + return metricResps.stream().flatMap(metricResp -> metricResp.getClassifications().stream()) .collect(Collectors.toSet()); } @@ -580,11 +537,10 @@ public class MetricServiceImpl extends ServiceImpl if (metricResp == null) { return drillDownDimensions; } - if (metricResp.getRelateDimension() != null - && !CollectionUtils.isEmpty( - metricResp.getRelateDimension().getDrillDownDimensions())) { - for (DrillDownDimension drillDownDimension : - metricResp.getRelateDimension().getDrillDownDimensions()) { + if (metricResp.getRelateDimension() != null && !CollectionUtils + .isEmpty(metricResp.getRelateDimension().getDrillDownDimensions())) { + for (DrillDownDimension drillDownDimension : metricResp.getRelateDimension() + .getDrillDownDimensions()) { if (drillDownDimension.isInheritedFromModel() && !drillDownDimension.isNecessary()) { continue; @@ -597,10 +553,8 @@ public class MetricServiceImpl extends ServiceImpl return drillDownDimensions; } for (DrillDownDimension drillDownDimension : modelResp.getDrillDownDimensions()) { - if (!drillDownDimensions.stream() - .map(DrillDownDimension::getDimensionId) - .collect(Collectors.toList()) - .contains(drillDownDimension.getDimensionId())) { + if (!drillDownDimensions.stream().map(DrillDownDimension::getDimensionId) + .collect(Collectors.toList()).contains(drillDownDimension.getDimensionId())) { drillDownDimension.setInheritedFromModel(true); drillDownDimensions.add(drillDownDimension); } @@ -639,29 +593,23 @@ public class MetricServiceImpl extends ServiceImpl MetaFilter metaFilter = new MetaFilter(); metaFilter.setModelIds(Lists.newArrayList(modelId)); List metricResps = getMetrics(metaFilter); - Map bizNameMap = - metricResps.stream() - .collect(Collectors.toMap(MetricResp::getBizName, a -> a, (k1, k2) -> k1)); - Map nameMap = - metricResps.stream() - .collect(Collectors.toMap(MetricResp::getName, a -> a, (k1, k2) -> k1)); + Map bizNameMap = metricResps.stream() + .collect(Collectors.toMap(MetricResp::getBizName, a -> a, (k1, k2) -> k1)); + Map nameMap = metricResps.stream() + .collect(Collectors.toMap(MetricResp::getName, a -> a, (k1, k2) -> k1)); for (MetricBaseReq metricReq : metricReqs) { if (bizNameMap.containsKey(metricReq.getBizName())) { MetricResp metricResp = bizNameMap.get(metricReq.getBizName()); if (!metricResp.getId().equals(metricReq.getId())) { - throw new RuntimeException( - String.format( - "该模型下存在相同的指标字段名:%s 创建人:%s", - metricReq.getBizName(), metricResp.getCreatedBy())); + throw new RuntimeException(String.format("该模型下存在相同的指标字段名:%s 创建人:%s", + metricReq.getBizName(), metricResp.getCreatedBy())); } } if (nameMap.containsKey(metricReq.getName())) { MetricResp metricResp = nameMap.get(metricReq.getName()); if (!metricResp.getId().equals(metricReq.getId())) { - throw new RuntimeException( - String.format( - "该模型下存在相同的指标名:%s 创建人:%s", - metricReq.getName(), metricResp.getCreatedBy())); + throw new RuntimeException(String.format("该模型下存在相同的指标名:%s 创建人:%s", + metricReq.getName(), metricResp.getCreatedBy())); } } } @@ -678,13 +626,9 @@ public class MetricServiceImpl extends ServiceImpl ModelFilter modelFilter = new ModelFilter(false, modelIds); Map modelMap = modelService.getModelMap(modelFilter); if (!CollectionUtils.isEmpty(metricDOS)) { - metricResps = - metricDOS.stream() - .map( - metricDO -> - MetricConverter.convert2MetricResp( - metricDO, modelMap, collect)) - .collect(Collectors.toList()); + metricResps = metricDOS.stream().map( + metricDO -> MetricConverter.convert2MetricResp(metricDO, modelMap, collect)) + .collect(Collectors.toList()); } return metricResps; } @@ -729,19 +673,15 @@ public class MetricServiceImpl extends ServiceImpl MetricResp metricResp = MetricConverter.convert2MetricResp(metricDO, new HashMap<>(), Lists.newArrayList()); fillDefaultAgg(metricResp); - return DataItem.builder() - .id(metricDO.getId() + Constants.UNDERLINE) - .name(metricDO.getName()) - .bizName(metricDO.getBizName()) - .modelId(metricDO.getModelId() + Constants.UNDERLINE) - .type(TypeEnums.METRIC) - .defaultAgg(metricResp.getDefaultAgg()) - .build(); + return DataItem.builder().id(metricDO.getId() + Constants.UNDERLINE) + .name(metricDO.getName()).bizName(metricDO.getBizName()) + .modelId(metricDO.getModelId() + Constants.UNDERLINE).type(TypeEnums.METRIC) + .defaultAgg(metricResp.getDefaultAgg()).build(); } @Override - public void batchFillMetricDefaultAgg( - List metricResps, List modelResps) { + public void batchFillMetricDefaultAgg(List metricResps, + List modelResps) { Map modelRespMap = modelResps.stream().collect(Collectors.toMap(ModelResp::getId, m -> m)); for (MetricResp metricResp : metricResps) { @@ -762,9 +702,8 @@ public class MetricServiceImpl extends ServiceImpl } private String getDefaultAgg(MetricResp metricResp, ModelResp modelResp) { - if (modelResp == null - || (Objects.nonNull(metricResp.getDefaultAgg()) - && !metricResp.getDefaultAgg().isEmpty())) { + if (modelResp == null || (Objects.nonNull(metricResp.getDefaultAgg()) + && !metricResp.getDefaultAgg().isEmpty())) { return metricResp.getDefaultAgg(); } // FIELD define will get from expr @@ -824,20 +763,12 @@ public class MetricServiceImpl extends ServiceImpl queryMetricReq.setDateInfo(null); } // 4. set groups - List dimensionBizNames = - dimensionResps.stream() - .filter(entry -> modelCluster.getModelIds().contains(entry.getModelId())) - .filter( - entry -> - queryMetricReq.getDimensionNames().contains(entry.getName()) - || queryMetricReq - .getDimensionNames() - .contains(entry.getBizName()) - || queryMetricReq - .getDimensionIds() - .contains(entry.getId())) - .map(SchemaItem::getBizName) - .collect(Collectors.toList()); + List dimensionBizNames = dimensionResps.stream() + .filter(entry -> modelCluster.getModelIds().contains(entry.getModelId())) + .filter(entry -> queryMetricReq.getDimensionNames().contains(entry.getName()) + || queryMetricReq.getDimensionNames().contains(entry.getBizName()) + || queryMetricReq.getDimensionIds().contains(entry.getId())) + .map(SchemaItem::getBizName).collect(Collectors.toList()); QueryStructReq queryStructReq = new QueryStructReq(); DateConf dateInfo = queryMetricReq.getDateInfo(); @@ -848,11 +779,9 @@ public class MetricServiceImpl extends ServiceImpl queryStructReq.getGroups().addAll(dimensionBizNames); } // 5. set aggregators - List metricBizNames = - metricResps.stream() - .filter(entry -> modelCluster.getModelIds().contains(entry.getModelId())) - .map(SchemaItem::getBizName) - .collect(Collectors.toList()); + List metricBizNames = metricResps.stream() + .filter(entry -> modelCluster.getModelIds().contains(entry.getModelId())) + .map(SchemaItem::getBizName).collect(Collectors.toList()); if (CollectionUtils.isEmpty(metricBizNames)) { throw new IllegalArgumentException( "Invalid input parameters, unable to obtain valid metrics"); @@ -898,18 +827,14 @@ public class MetricServiceImpl extends ServiceImpl } } } - String keyWithMaxSize = - modelClusterToMatchCount.entrySet().stream() - .max(Comparator.comparingInt(entry -> entry.getValue().size())) - .map(Map.Entry::getKey) - .orElse(null); + String keyWithMaxSize = modelClusterToMatchCount.entrySet().stream() + .max(Comparator.comparingInt(entry -> entry.getValue().size())) + .map(Map.Entry::getKey).orElse(null); return modelClusterMap.get(keyWithMaxSize); } - private Set getModelIds( - Set modelIdsByDomainId, - List metricResps, + private Set getModelIds(Set modelIdsByDomainId, List metricResps, List dimensionResps) { Set result = new HashSet<>(); if (org.apache.commons.collections.CollectionUtils.isNotEmpty(modelIdsByDomainId)) { @@ -920,10 +845,8 @@ public class MetricServiceImpl extends ServiceImpl metricResps.stream().map(entry -> entry.getModelId()).collect(Collectors.toSet()); result.addAll(metricModelIds); - Set dimensionModelIds = - dimensionResps.stream() - .map(entry -> entry.getModelId()) - .collect(Collectors.toSet()); + Set dimensionModelIds = dimensionResps.stream().map(entry -> entry.getModelId()) + .collect(Collectors.toSet()); result.addAll(dimensionModelIds); return result; } @@ -942,9 +865,8 @@ public class MetricServiceImpl extends ServiceImpl } private Set getModelIdsByDomainId(QueryMetricReq queryMetricReq) { - List modelResps = - modelService.getAllModelByDomainIds( - Collections.singletonList(queryMetricReq.getDomainId())); + List modelResps = modelService + .getAllModelByDomainIds(Collections.singletonList(queryMetricReq.getDomainId())); return modelResps.stream().map(ModelResp::getId).collect(Collectors.toSet()); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelRelaServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelRelaServiceImpl.java index 5b7ee2660..a332e7afc 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelRelaServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/ModelRelaServiceImpl.java @@ -48,9 +48,7 @@ public class ModelRelaServiceImpl extends ServiceImpl metricResps = metricService.getMetrics(metaFilter); List dimensionResps = dimensionService.getDimensions(metaFilter); - return UnAvailableItemResp.builder() - .dimensionResps(dimensionResps) - .metricResps(metricResps) + return UnAvailableItemResp.builder().dimensionResps(dimensionResps).metricResps(metricResps) .build(); } @@ -204,9 +197,8 @@ public class ModelServiceImpl implements ModelService { private void checkParams(ModelReq modelReq) { String forbiddenCharacters = NameCheckUtils.findForbiddenCharacters(modelReq.getName()); if (StringUtils.isNotBlank(forbiddenCharacters)) { - String message = - String.format( - "模型名称[%s]包含特殊字符(%s), 请修改", modelReq.getName(), forbiddenCharacters); + String message = String.format("模型名称[%s]包含特殊字符(%s), 请修改", modelReq.getName(), + forbiddenCharacters); throw new InvalidArgumentException(message); } @@ -228,10 +220,8 @@ public class ModelServiceImpl implements ModelService { NameCheckUtils.findForbiddenCharacters(measure.getName()); if (StringUtils.isNotBlank(measure.getName()) && StringUtils.isNotBlank(measureForbiddenCharacters)) { - String message = - String.format( - "度量[%s]包含特殊字符(%s), 请修改", - measure.getName(), measureForbiddenCharacters); + String message = String.format("度量[%s]包含特殊字符(%s), 请修改", measure.getName(), + measureForbiddenCharacters); throw new InvalidArgumentException(message); } } @@ -239,9 +229,8 @@ public class ModelServiceImpl implements ModelService { String dimForbiddenCharacters = NameCheckUtils.findForbiddenCharacters(dim.getName()); if (StringUtils.isNotBlank(dim.getName()) && StringUtils.isNotBlank(dimForbiddenCharacters)) { - String message = - String.format( - "维度[%s]包含特殊字符(%s), 请修改", dim.getName(), dimForbiddenCharacters); + String message = String.format("维度[%s]包含特殊字符(%s), 请修改", dim.getName(), + dimForbiddenCharacters); throw new InvalidArgumentException(message); } } @@ -250,10 +239,8 @@ public class ModelServiceImpl implements ModelService { NameCheckUtils.findForbiddenCharacters(identify.getName()); if (StringUtils.isNotBlank(identify.getName()) && StringUtils.isNotBlank(identifyForbiddenCharacters)) { - String message = - String.format( - "主键/外键[%s]包含特殊字符(%s), 请修改", - identify.getName(), identifyForbiddenCharacters); + String message = String.format("主键/外键[%s]包含特殊字符(%s), 请修改", identify.getName(), + identifyForbiddenCharacters); throw new InvalidArgumentException(message); } } @@ -304,26 +291,21 @@ public class ModelServiceImpl implements ModelService { List modelRespsAuthInheritDomain = getModelRespAuthInheritDomain(user, domainId, authType); modelRespSet.addAll(modelRespsAuthInheritDomain); - return modelRespSet.stream() - .sorted(Comparator.comparingLong(ModelResp::getId)) + return modelRespSet.stream().sorted(Comparator.comparingLong(ModelResp::getId)) .collect(Collectors.toList()); } - public List getModelRespAuthInheritDomain( - User user, Long domainId, AuthType authType) { + public List getModelRespAuthInheritDomain(User user, Long domainId, + AuthType authType) { List domainIds = - domainService.getDomainAuthSet(user, authType).stream() - .filter( - domainResp -> { - if (domainId == null) { - return true; - } else { - return domainId.equals(domainResp.getId()) - || domainId.equals(domainResp.getParentId()); - } - }) - .map(DomainResp::getId) - .collect(Collectors.toList()); + domainService.getDomainAuthSet(user, authType).stream().filter(domainResp -> { + if (domainId == null) { + return true; + } else { + return domainId.equals(domainResp.getId()) + || domainId.equals(domainResp.getParentId()); + } + }).map(DomainResp::getId).collect(Collectors.toList()); if (CollectionUtils.isEmpty(domainIds)) { return Lists.newArrayList(); } @@ -342,16 +324,14 @@ public class ModelServiceImpl implements ModelService { Set orgIds = userService.getUserAllOrgId(user.getName()); List modelWithAuth = Lists.newArrayList(); if (authTypeEnum.equals(AuthType.ADMIN)) { - modelWithAuth = - modelResps.stream() - .filter(modelResp -> checkAdminPermission(orgIds, user, modelResp)) - .collect(Collectors.toList()); + modelWithAuth = modelResps.stream() + .filter(modelResp -> checkAdminPermission(orgIds, user, modelResp)) + .collect(Collectors.toList()); } if (authTypeEnum.equals(AuthType.VISIBLE)) { - modelWithAuth = - modelResps.stream() - .filter(domainResp -> checkDataSetPermission(orgIds, user, domainResp)) - .collect(Collectors.toList()); + modelWithAuth = modelResps.stream() + .filter(domainResp -> checkDataSetPermission(orgIds, user, domainResp)) + .collect(Collectors.toList()); } return modelWithAuth; } @@ -368,8 +348,7 @@ public class ModelServiceImpl implements ModelService { if (CollectionUtils.isEmpty(modelResps)) { return modelResps; } - return modelResps.stream() - .filter(modelResp -> domainIds.contains(modelResp.getDomainId())) + return modelResps.stream().filter(modelResp -> domainIds.contains(modelResp.getDomainId())) .collect(Collectors.toList()); } @@ -434,35 +413,23 @@ public class ModelServiceImpl implements ModelService { if (CollectionUtils.isEmpty(modelDOS)) { return; } - modelDOS = - modelDOS.stream() - .peek( - modelDO -> { - modelDO.setStatus(metaBatchReq.getStatus()); - modelDO.setUpdatedAt(new Date()); - modelDO.setUpdatedBy(user.getName()); - if (StatusEnum.OFFLINE - .getCode() - .equals(metaBatchReq.getStatus()) - || StatusEnum.DELETED - .getCode() - .equals(metaBatchReq.getStatus())) { - metricService.sendMetricEventBatch( - Lists.newArrayList(modelDO.getId()), - EventType.DELETE); - dimensionService.sendDimensionEventBatch( - Lists.newArrayList(modelDO.getId()), - EventType.DELETE); - } else if (StatusEnum.ONLINE - .getCode() - .equals(metaBatchReq.getStatus())) { - metricService.sendMetricEventBatch( - Lists.newArrayList(modelDO.getId()), EventType.ADD); - dimensionService.sendDimensionEventBatch( - Lists.newArrayList(modelDO.getId()), EventType.ADD); - } - }) - .collect(Collectors.toList()); + modelDOS = modelDOS.stream().peek(modelDO -> { + modelDO.setStatus(metaBatchReq.getStatus()); + modelDO.setUpdatedAt(new Date()); + modelDO.setUpdatedBy(user.getName()); + if (StatusEnum.OFFLINE.getCode().equals(metaBatchReq.getStatus()) + || StatusEnum.DELETED.getCode().equals(metaBatchReq.getStatus())) { + metricService.sendMetricEventBatch(Lists.newArrayList(modelDO.getId()), + EventType.DELETE); + dimensionService.sendDimensionEventBatch(Lists.newArrayList(modelDO.getId()), + EventType.DELETE); + } else if (StatusEnum.ONLINE.getCode().equals(metaBatchReq.getStatus())) { + metricService.sendMetricEventBatch(Lists.newArrayList(modelDO.getId()), + EventType.ADD); + dimensionService.sendDimensionEventBatch(Lists.newArrayList(modelDO.getId()), + EventType.ADD); + } + }).collect(Collectors.toList()); modelRepository.batchUpdate(modelDOS); } @@ -472,14 +439,13 @@ public class ModelServiceImpl implements ModelService { private List convert(List dateInfoDOList) { List dateInfoCommendList = new ArrayList<>(); - dateInfoDOList.forEach( - dateInfoDO -> { - DateInfoReq dateInfoCommend = new DateInfoReq(); - BeanUtils.copyProperties(dateInfoDO, dateInfoCommend); - dateInfoCommend.setUnavailableDateList( - JsonUtil.toList(dateInfoDO.getUnavailableDateList(), String.class)); - dateInfoCommendList.add(dateInfoCommend); - }); + dateInfoDOList.forEach(dateInfoDO -> { + DateInfoReq dateInfoCommend = new DateInfoReq(); + BeanUtils.copyProperties(dateInfoDO, dateInfoCommend); + dateInfoCommend.setUnavailableDateList( + JsonUtil.toList(dateInfoDO.getUnavailableDateList(), String.class)); + dateInfoCommendList.add(dateInfoCommend); + }); return dateInfoCommendList; } @@ -504,8 +470,8 @@ public class ModelServiceImpl implements ModelService { return false; } - public static boolean checkDataSetPermission( - Set orgIds, User user, ModelResp modelResp) { + public static boolean checkDataSetPermission(Set orgIds, User user, + ModelResp modelResp) { if (checkAdminPermission(orgIds, user, modelResp)) { return true; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/QueryRuleServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/QueryRuleServiceImpl.java index f12ebf1d6..7b70dcc76 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/QueryRuleServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/QueryRuleServiceImpl.java @@ -26,8 +26,8 @@ public class QueryRuleServiceImpl implements QueryRuleService { private final QueryRuleRepository queryRuleRepository; private final DataSetService dataSetService; - public QueryRuleServiceImpl( - QueryRuleRepository queryRuleRepository, DataSetService dataSetService) { + public QueryRuleServiceImpl(QueryRuleRepository queryRuleRepository, + DataSetService dataSetService) { this.queryRuleRepository = queryRuleRepository; this.dataSetService = dataSetService; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/RetrieveServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/RetrieveServiceImpl.java index aa8093a5f..c176f1a65 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/RetrieveServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/RetrieveServiceImpl.java @@ -48,13 +48,17 @@ public class RetrieveServiceImpl implements RetrieveService { private static final int RESULT_SIZE = 10; - @Autowired private DataSetService dataSetService; + @Autowired + private DataSetService dataSetService; - @Autowired private SchemaService schemaService; + @Autowired + private SchemaService schemaService; - @Autowired private KnowledgeBaseService knowledgeBaseService; + @Autowired + private KnowledgeBaseService knowledgeBaseService; - @Autowired private SearchMatchStrategy searchMatchStrategy; + @Autowired + private SearchMatchStrategy searchMatchStrategy; @Override public List retrieve(QueryNLReq queryNLReq) { @@ -65,9 +69,8 @@ public class RetrieveServiceImpl implements RetrieveService { schemaService.getSemanticSchema(queryNLReq.getDataSetIds()); List metricsDb = semanticSchemaDb.getMetrics(); final Map dataSetIdToName = semanticSchemaDb.getDataSetIdToName(); - Map> modelIdToDataSetIds = - dataSetService.getModelIdToDataSetIds( - new ArrayList<>(dataSetIdToName.keySet()), User.getFakeUser()); + Map> modelIdToDataSetIds = dataSetService.getModelIdToDataSetIds( + new ArrayList<>(dataSetIdToName.keySet()), User.getFakeUser()); // 2.detect by segment List originals = knowledgeBaseService.getTerms(queryText, modelIdToDataSetIds); log.debug("hanlp parse result: {}", originals); @@ -86,14 +89,9 @@ public class RetrieveServiceImpl implements RetrieveService { Optional>> mostSimilarSearchResult = regTextMap.entrySet().stream() .filter(entry -> CollectionUtils.isNotEmpty(entry.getValue())) - .reduce( - (entry1, entry2) -> - entry1.getKey().getDetectSegment().length() - >= entry2.getKey() - .getDetectSegment() - .length() - ? entry1 - : entry2); + .reduce((entry1, entry2) -> entry1.getKey().getDetectSegment() + .length() >= entry2.getKey().getDetectSegment().length() ? entry1 + : entry2); // 4.optimize the results after the query if (!mostSimilarSearchResult.isPresent()) { @@ -108,12 +106,8 @@ public class RetrieveServiceImpl implements RetrieveService { List possibleDataSets = getPossibleDataSets(queryNLReq, originals, dataSetIds); // 5.1 priority dimension metric - boolean existMetricAndDimension = - searchMetricAndDimension( - new HashSet<>(possibleDataSets), - dataSetIdToName, - searchTextEntry, - searchResults); + boolean existMetricAndDimension = searchMetricAndDimension(new HashSet<>(possibleDataSets), + dataSetIdToName, searchTextEntry, searchResults); // 5.2 process based on dimension values MatchText matchText = searchTextEntry.getKey(); @@ -123,24 +117,17 @@ public class RetrieveServiceImpl implements RetrieveService { for (Map.Entry natureToNameEntry : natureToNameMap.entrySet()) { - Set searchResultSet = - searchDimensionValue( - metricsDb, - dataSetIdToName, - dataSetInfoStat.getMetricDataSetCount(), - existMetricAndDimension, - matchText, - natureToNameMap, - natureToNameEntry, - queryNLReq.getQueryFilters()); + Set searchResultSet = searchDimensionValue(metricsDb, dataSetIdToName, + dataSetInfoStat.getMetricDataSetCount(), existMetricAndDimension, matchText, + natureToNameMap, natureToNameEntry, queryNLReq.getQueryFilters()); searchResults.addAll(searchResultSet); } return searchResults.stream().limit(RESULT_SIZE).collect(Collectors.toList()); } - private List getPossibleDataSets( - QueryNLReq queryCtx, List originals, Set dataSetIds) { + private List getPossibleDataSets(QueryNLReq queryCtx, List originals, + Set dataSetIds) { if (CollectionUtils.isNotEmpty(dataSetIds)) { return new ArrayList<>(dataSetIds); } @@ -155,15 +142,10 @@ public class RetrieveServiceImpl implements RetrieveService { return possibleDataSets; } - private Set searchDimensionValue( - List metricsDb, - Map modelToName, - long metricModelCount, - boolean existMetricAndDimension, - MatchText matchText, - Map natureToNameMap, - Map.Entry natureToNameEntry, - QueryFilters queryFilters) { + private Set searchDimensionValue(List metricsDb, + Map modelToName, long metricModelCount, boolean existMetricAndDimension, + MatchText matchText, Map natureToNameMap, + Map.Entry natureToNameEntry, QueryFilters queryFilters) { Set searchResults = new LinkedHashSet(); String nature = natureToNameEntry.getKey(); @@ -175,15 +157,10 @@ public class RetrieveServiceImpl implements RetrieveService { if (SchemaElementType.ENTITY.equals(schemaElementType)) { return searchResults; } - // If there are no metric/dimension, complete the metric information - SearchResult searchResult = - SearchResult.builder() - .modelId(modelId) - .modelName(modelToName.get(modelId)) - .recommend(matchText.getRegText() + wordName) - .schemaElementType(schemaElementType) - .subRecommend(wordName) - .build(); + // If there are no metric/dimension, complete the metric information + SearchResult searchResult = SearchResult.builder().modelId(modelId) + .modelName(modelToName.get(modelId)).recommend(matchText.getRegText() + wordName) + .schemaElementType(schemaElementType).subRecommend(wordName).build(); if (metricModelCount <= 0 && !existMetricAndDimension) { if (filterByQueryFilter(wordName, queryFilters)) { @@ -191,24 +168,15 @@ public class RetrieveServiceImpl implements RetrieveService { } searchResults.add(searchResult); int metricSize = getMetricSize(natureToNameMap); - List metrics = - filerMetricsByModel(metricsDb, modelId, metricSize * 3).stream() - .limit(metricSize) - .collect(Collectors.toList()); + List metrics = filerMetricsByModel(metricsDb, modelId, metricSize * 3).stream() + .limit(metricSize).collect(Collectors.toList()); for (String metric : metrics) { - SearchResult result = - SearchResult.builder() - .modelId(modelId) - .modelName(modelToName.get(modelId)) - .recommend( - matchText.getRegText() - + wordName - + DictWordType.SPACE - + metric) - .subRecommend(wordName + DictWordType.SPACE + metric) - .isComplete(false) - .build(); + SearchResult result = SearchResult.builder().modelId(modelId) + .modelName(modelToName.get(modelId)) + .recommend(matchText.getRegText() + wordName + DictWordType.SPACE + metric) + .subRecommend(wordName + DictWordType.SPACE + metric).isComplete(false) + .build(); searchResults.add(result); } } else { @@ -238,22 +206,19 @@ public class RetrieveServiceImpl implements RetrieveService { return true; } - protected List filerMetricsByModel( - List metricsDb, Long model, int metricSize) { + protected List filerMetricsByModel(List metricsDb, Long model, + int metricSize) { if (CollectionUtils.isEmpty(metricsDb)) { return Lists.newArrayList(); } return metricsDb.stream() .filter(mapDO -> Objects.nonNull(mapDO) && model.equals(mapDO.getDataSetId())) .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) - .flatMap( - entry -> { - List result = new ArrayList<>(); - result.add(entry.getName()); - return result.stream(); - }) - .limit(metricSize) - .collect(Collectors.toList()); + .flatMap(entry -> { + List result = new ArrayList<>(); + result.add(entry.getName()); + return result.stream(); + }).limit(metricSize).collect(Collectors.toList()); } /** @@ -267,35 +232,23 @@ public class RetrieveServiceImpl implements RetrieveService { Set possibleModels) { List recommendValues = recommendTextListEntry.getValue(); return recommendValues.stream() - .flatMap( - entry -> - entry.getNatures().stream() - .filter( - nature -> { - if (CollectionUtils.isEmpty(possibleModels)) { - return true; - } - Long model = NatureHelper.getDataSetId(nature); - return possibleModels.contains(model); - }) - .map( - nature -> { - DictWord posDO = new DictWord(); - posDO.setWord(entry.getName()); - posDO.setNature(nature); - return posDO; - })) - .sorted(Comparator.comparingInt(a -> a.getWord().length())) - .collect( - Collectors.toMap( - DictWord::getNature, - DictWord::getWord, - (value1, value2) -> value1, - LinkedHashMap::new)); + .flatMap(entry -> entry.getNatures().stream().filter(nature -> { + if (CollectionUtils.isEmpty(possibleModels)) { + return true; + } + Long model = NatureHelper.getDataSetId(nature); + return possibleModels.contains(model); + }).map(nature -> { + DictWord posDO = new DictWord(); + posDO.setWord(entry.getName()); + posDO.setNature(nature); + return posDO; + })).sorted(Comparator.comparingInt(a -> a.getWord().length())) + .collect(Collectors.toMap(DictWord::getNature, DictWord::getWord, + (value1, value2) -> value1, LinkedHashMap::new)); } - private boolean searchMetricAndDimension( - Set possibleDataSets, + private boolean searchMetricAndDimension(Set possibleDataSets, Map modelToName, Map.Entry> searchTextEntry, Set searchResults) { @@ -306,15 +259,12 @@ public class RetrieveServiceImpl implements RetrieveService { for (HanlpMapResult hanlpMapResult : hanlpMapResults) { - List dimensionMetricClassIds = - hanlpMapResult.getNatures().stream() - .map( - nature -> - new ModelWithSemanticType( - NatureHelper.getDataSetId(nature), - NatureHelper.convertToElementType(nature))) - .filter(entry -> matchCondition(entry, possibleDataSets)) - .collect(Collectors.toList()); + List dimensionMetricClassIds = hanlpMapResult.getNatures() + .stream() + .map(nature -> new ModelWithSemanticType(NatureHelper.getDataSetId(nature), + NatureHelper.convertToElementType(nature))) + .filter(entry -> matchCondition(entry, possibleDataSets)) + .collect(Collectors.toList()); if (CollectionUtils.isEmpty(dimensionMetricClassIds)) { continue; @@ -324,21 +274,15 @@ public class RetrieveServiceImpl implements RetrieveService { Long modelId = modelWithSemanticType.getModel(); SchemaElementType schemaElementType = modelWithSemanticType.getSchemaElementType(); SearchResult searchResult = - SearchResult.builder() - .modelId(modelId) - .modelName(modelToName.get(modelId)) + SearchResult.builder().modelId(modelId).modelName(modelToName.get(modelId)) .recommend(matchText.getRegText() + hanlpMapResult.getName()) .subRecommend(hanlpMapResult.getName()) - .schemaElementType(schemaElementType) - .build(); - // visibility to filter metrics + .schemaElementType(schemaElementType).build(); + // visibility to filter metrics searchResults.add(searchResult); } - log.debug( - "parseResult:{},dimensionMetricClassIds:{},possibleDataSets:{}", - hanlpMapResult, - dimensionMetricClassIds, - possibleDataSets); + log.debug("parseResult:{},dimensionMetricClassIds:{},possibleDataSets:{}", + hanlpMapResult, dimensionMetricClassIds, possibleDataSets); } log.info("searchMetricAndDimension searchResults:{}", searchResults); return existMetric; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SchemaServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SchemaServiceImpl.java index da5542d8d..410f2e463 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SchemaServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/SchemaServiceImpl.java @@ -103,17 +103,10 @@ public class SchemaServiceImpl implements SchemaService { @Value("${s2.schema.cache.enable:true}") private boolean schemaCacheEnable; - public SchemaServiceImpl( - ModelService modelService, - DimensionService dimensionService, - MetricService metricService, - DomainService domainService, - DataSetService dataSetService, - ModelRelaService modelRelaService, - StatUtils statUtils, - TagMetaService tagService, - TermService termService, - DatabaseService databaseService) { + public SchemaServiceImpl(ModelService modelService, DimensionService dimensionService, + MetricService metricService, DomainService domainService, DataSetService dataSetService, + ModelRelaService modelRelaService, StatUtils statUtils, TagMetaService tagService, + TermService termService, DatabaseService databaseService) { this.modelService = modelService; this.dimensionService = dimensionService; this.metricService = metricService; @@ -142,8 +135,7 @@ public class SchemaServiceImpl implements SchemaService { if (dataSetId == null) { return null; } - return fetchDataSetSchema(new DataSetFilterReq(dataSetId)).stream() - .findFirst() + return fetchDataSetSchema(new DataSetFilterReq(dataSetId)).stream().findFirst() .orElse(null); } @@ -159,10 +151,8 @@ public class SchemaServiceImpl implements SchemaService { ids.add(dataSetId); List dataSetSchemaResps = fetchDataSetSchema(ids); if (!CollectionUtils.isEmpty(dataSetSchemaResps)) { - Optional dataSetSchemaResp = - dataSetSchemaResps.stream() - .filter(d -> d.getId().equals(dataSetId)) - .findFirst(); + Optional dataSetSchemaResp = dataSetSchemaResps.stream() + .filter(d -> d.getId().equals(dataSetId)).findFirst(); if (dataSetSchemaResp.isPresent()) { DataSetSchemaResp dataSetSchema = dataSetSchemaResp.get(); return DataSetSchemaBuilder.build(dataSetSchema); @@ -199,11 +189,8 @@ public class SchemaServiceImpl implements SchemaService { Map dataSetRespMap = getDataSetMap(dataSetResps); Set domainIds = dataSetResps.stream().map(DataSetResp::getDomainId).collect(Collectors.toSet()); - List modelIds = - dataSetRespMap.values().stream() - .map(DataSetResp::getAllModels) - .flatMap(Collection::stream) - .collect(Collectors.toList()); + List modelIds = dataSetRespMap.values().stream().map(DataSetResp::getAllModels) + .flatMap(Collection::stream).collect(Collectors.toList()); Map> termMaps = termService.getTermSets(domainIds); metaFilter.setModelIds(modelIds); @@ -228,28 +215,22 @@ public class SchemaServiceImpl implements SchemaService { } List metricSchemaResps = MetricConverter.filterByDataSet(metricResps, dataSetResp).stream() - .map(this::convert) - .collect(Collectors.toList()); + .map(this::convert).collect(Collectors.toList()); List dimSchemaResps = DimensionConverter.filterByDataSet(dimensionResps, dataSetResp).stream() - .map(this::convert) - .collect(Collectors.toList()); + .map(this::convert).collect(Collectors.toList()); DataSetSchemaResp dataSetSchemaResp = new DataSetSchemaResp(); BeanUtils.copyProperties(dataSetResp, dataSetSchemaResp); dataSetSchemaResp.setDimensions(dimSchemaResps); dataSetSchemaResp.setMetrics(metricSchemaResps); - dataSetSchemaResp.setModelResps( - modelResps.stream() - .filter( - modelResp -> - dataSetResp.getAllModels().contains(modelResp.getId())) - .collect(Collectors.toList())); + dataSetSchemaResp.setModelResps(modelResps.stream() + .filter(modelResp -> dataSetResp.getAllModels().contains(modelResp.getId())) + .collect(Collectors.toList())); dataSetSchemaResp.setTermResps( termMaps.getOrDefault(dataSetResp.getDomainId(), Lists.newArrayList())); if (!CollectionUtils.isEmpty(dataSetSchemaResp.getModelResps())) { - DatabaseResp databaseResp = - databaseService.getDatabase( - dataSetSchemaResp.getModelResps().get(0).getDatabaseId()); + DatabaseResp databaseResp = databaseService + .getDatabase(dataSetSchemaResp.getModelResps().get(0).getDatabaseId()); dataSetSchemaResp.setDatabaseType(databaseResp.getType()); } dataSetSchemaResps.add(dataSetSchemaResp); @@ -265,9 +246,8 @@ public class SchemaServiceImpl implements SchemaService { } MetaFilter metaFilter = new MetaFilter(modelIds); metaFilter.setStatus(StatusEnum.ONLINE.getCode()); - Map> metricRespMap = - metricService.getMetrics(metaFilter).stream() - .collect(Collectors.groupingBy(MetricResp::getModelId)); + Map> metricRespMap = metricService.getMetrics(metaFilter).stream() + .collect(Collectors.groupingBy(MetricResp::getModelId)); Map> dimensionRespsMap = dimensionService.getDimensions(metaFilter).stream() .collect(Collectors.groupingBy(DimensionResp::getModelId)); @@ -285,19 +265,15 @@ public class SchemaServiceImpl implements SchemaService { metricResps.stream().map(this::convert).collect(Collectors.toList()); List dimensionResps = dimensionRespsMap.getOrDefault(modelId, Lists.newArrayList()).stream() - .map(this::convert) - .collect(Collectors.toList()); + .map(this::convert).collect(Collectors.toList()); ModelSchemaResp modelSchemaResp = new ModelSchemaResp(); BeanUtils.copyProperties(modelResp, modelSchemaResp); modelSchemaResp.setDimensions(dimensionResps); modelSchemaResp.setMetrics(metricSchemaResps); - modelSchemaResp.setModelRelas( - modelRelas.stream() - .filter( - modelRela -> - modelRela.getFromModelId().equals(modelId) - || modelRela.getToModelId().equals(modelId)) - .collect(Collectors.toList())); + modelSchemaResp.setModelRelas(modelRelas.stream() + .filter(modelRela -> modelRela.getFromModelId().equals(modelId) + || modelRela.getToModelId().equals(modelId)) + .collect(Collectors.toList())); modelSchemaResps.add(modelSchemaResp); } return modelSchemaResps; @@ -305,17 +281,11 @@ public class SchemaServiceImpl implements SchemaService { private void fillCnt(List dataSetSchemaResps, List statInfos) { - Map typeIdAndStatPair = - statInfos.stream() - .collect( - Collectors.toMap( - itemUseInfo -> - itemUseInfo.getType() - + AT_SYMBOL - + AT_SYMBOL - + itemUseInfo.getBizName(), - itemUseInfo -> itemUseInfo, - (item1, item2) -> item1)); + Map typeIdAndStatPair = statInfos.stream() + .collect(Collectors.toMap( + itemUseInfo -> itemUseInfo.getType() + AT_SYMBOL + AT_SYMBOL + + itemUseInfo.getBizName(), + itemUseInfo -> itemUseInfo, (item1, item2) -> item1)); log.debug("typeIdAndStatPair:{}", typeIdAndStatPair); for (DataSetSchemaResp dataSetSchemaResp : dataSetSchemaResps) { fillDimCnt(dataSetSchemaResp, typeIdAndStatPair); @@ -323,49 +293,39 @@ public class SchemaServiceImpl implements SchemaService { } } - private void fillMetricCnt( - DataSetSchemaResp dataSetSchemaResp, Map typeIdAndStatPair) { + private void fillMetricCnt(DataSetSchemaResp dataSetSchemaResp, + Map typeIdAndStatPair) { List metrics = dataSetSchemaResp.getMetrics(); if (CollectionUtils.isEmpty(dataSetSchemaResp.getMetrics())) { return; } if (!CollectionUtils.isEmpty(metrics)) { - metrics.stream() - .forEach( - metric -> { - String key = - TypeEnums.METRIC.name().toLowerCase() - + AT_SYMBOL - + AT_SYMBOL - + metric.getBizName(); - if (typeIdAndStatPair.containsKey(key)) { - metric.setUseCnt(typeIdAndStatPair.get(key).getUseCnt()); - } - }); + metrics.stream().forEach(metric -> { + String key = TypeEnums.METRIC.name().toLowerCase() + AT_SYMBOL + AT_SYMBOL + + metric.getBizName(); + if (typeIdAndStatPair.containsKey(key)) { + metric.setUseCnt(typeIdAndStatPair.get(key).getUseCnt()); + } + }); } dataSetSchemaResp.setMetrics(metrics); } - private void fillDimCnt( - DataSetSchemaResp dataSetSchemaResp, Map typeIdAndStatPair) { + private void fillDimCnt(DataSetSchemaResp dataSetSchemaResp, + Map typeIdAndStatPair) { List dimensions = dataSetSchemaResp.getDimensions(); if (CollectionUtils.isEmpty(dataSetSchemaResp.getDimensions())) { return; } if (!CollectionUtils.isEmpty(dimensions)) { - dimensions.stream() - .forEach( - dim -> { - String key = - TypeEnums.DIMENSION.name().toLowerCase() - + AT_SYMBOL - + AT_SYMBOL - + dim.getBizName(); - if (typeIdAndStatPair.containsKey(key)) { - dim.setUseCnt(typeIdAndStatPair.get(key).getUseCnt()); - } - }); + dimensions.stream().forEach(dim -> { + String key = TypeEnums.DIMENSION.name().toLowerCase() + AT_SYMBOL + AT_SYMBOL + + dim.getBizName(); + if (typeIdAndStatPair.containsKey(key)) { + dim.setUseCnt(typeIdAndStatPair.get(key).getUseCnt()); + } + }); } dataSetSchemaResp.setDimensions(dimensions); } @@ -404,11 +364,9 @@ public class SchemaServiceImpl implements SchemaService { public List getModelList(List modelIds) { List modelRespList = new ArrayList<>(); if (!org.apache.commons.collections.CollectionUtils.isEmpty(modelIds)) { - modelIds.stream() - .forEach( - m -> { - modelRespList.add(modelService.getModel(m)); - }); + modelIds.stream().forEach(m -> { + modelRespList.add(modelService.getModel(m)); + }); } return modelRespList; } @@ -434,21 +392,14 @@ public class SchemaServiceImpl implements SchemaService { } else if (!CollectionUtils.isEmpty(schemaFilterReq.getModelIds())) { List modelSchemaResps = fetchModelSchemaResps(schemaFilterReq.getModelIds()); - semanticSchemaResp.setMetrics( - modelSchemaResps.stream() - .map(ModelSchemaResp::getMetrics) - .flatMap(Collection::stream) - .collect(Collectors.toList())); - semanticSchemaResp.setDimensions( - modelSchemaResps.stream() - .map(ModelSchemaResp::getDimensions) - .flatMap(Collection::stream) - .collect(Collectors.toList())); - semanticSchemaResp.setModelRelas( - modelSchemaResps.stream() - .map(ModelSchemaResp::getModelRelas) - .flatMap(Collection::stream) - .collect(Collectors.toList())); + semanticSchemaResp.setMetrics(modelSchemaResps.stream().map(ModelSchemaResp::getMetrics) + .flatMap(Collection::stream).collect(Collectors.toList())); + semanticSchemaResp + .setDimensions(modelSchemaResps.stream().map(ModelSchemaResp::getDimensions) + .flatMap(Collection::stream).collect(Collectors.toList())); + semanticSchemaResp + .setModelRelas(modelSchemaResps.stream().map(ModelSchemaResp::getModelRelas) + .flatMap(Collection::stream).collect(Collectors.toList())); semanticSchemaResp.setModelResps( modelSchemaResps.stream().map(this::convert).collect(Collectors.toList())); semanticSchemaResp.setSchemaType(SchemaType.MODEL); @@ -485,13 +436,11 @@ public class SchemaServiceImpl implements SchemaService { @Override public List getStatInfo(ItemUseReq itemUseReq) { if (itemUseReq.getCacheEnable()) { - return itemUseCache.get( - JsonUtil.toString(itemUseReq), - () -> { - List data = statUtils.getStatInfo(itemUseReq); - itemUseCache.put(JsonUtil.toString(itemUseReq), data); - return data; - }); + return itemUseCache.get(JsonUtil.toString(itemUseReq), () -> { + List data = statUtils.getStatInfo(itemUseReq); + itemUseCache.put(JsonUtil.toString(itemUseReq), data); + return data; + }); } return statUtils.getStatInfo(itemUseReq); } @@ -499,28 +448,17 @@ public class SchemaServiceImpl implements SchemaService { @Override public List getDomainDataSetTree() { List domainResps = domainService.getDomainList(); - List itemResps = - domainResps.stream() - .map( - domain -> - new ItemResp( - domain.getId(), - domain.getParentId(), - domain.getName(), - TypeEnums.DOMAIN)) - .collect(Collectors.toList()); + List itemResps = domainResps.stream().map(domain -> new ItemResp(domain.getId(), + domain.getParentId(), domain.getName(), TypeEnums.DOMAIN)) + .collect(Collectors.toList()); Map itemRespMap = itemResps.stream().collect(Collectors.toMap(ItemResp::getId, item -> item)); List dataSetResps = dataSetService.getDataSetList(new MetaFilter()); for (DataSetResp dataSetResp : dataSetResps) { ItemResp itemResp = itemRespMap.get(dataSetResp.getDomainId()); if (itemResp != null) { - ItemResp dataSet = - new ItemResp( - dataSetResp.getId(), - dataSetResp.getDomainId(), - dataSetResp.getName(), - TypeEnums.DATASET); + ItemResp dataSet = new ItemResp(dataSetResp.getId(), dataSetResp.getDomainId(), + dataSetResp.getName(), TypeEnums.DATASET); itemResp.getChildren().add(dataSet); } } @@ -528,10 +466,8 @@ public class SchemaServiceImpl implements SchemaService { } private void fillStaticInfo(List dataSetSchemaResps) { - List dataSetIds = - dataSetSchemaResps.stream() - .map(DataSetSchemaResp::getId) - .collect(Collectors.toList()); + List dataSetIds = dataSetSchemaResps.stream().map(DataSetSchemaResp::getId) + .collect(Collectors.toList()); ItemUseReq itemUseReq = new ItemUseReq(); itemUseReq.setModelIds(dataSetIds); @@ -567,11 +503,9 @@ public class SchemaServiceImpl implements SchemaService { } @Override - public void getSchemaYamlTpl( - SemanticSchemaResp semanticSchemaResp, + public void getSchemaYamlTpl(SemanticSchemaResp semanticSchemaResp, Map> dimensionYamlMap, - List dataModelYamlTplList, - List metricYamlTplList, + List dataModelYamlTplList, List metricYamlTplList, Map modelIdName) { List modelResps = semanticSchemaResp.getModelResps(); @@ -587,15 +521,10 @@ public class SchemaServiceImpl implements SchemaService { if (!dimensionYamlMap.containsKey(modelResp.getBizName())) { dimensionYamlMap.put(modelResp.getBizName(), new ArrayList<>()); } - List dimensionRespList = - dimensionResps.stream() - .filter( - d -> - d.getModelBizName() - .equalsIgnoreCase(modelResp.getBizName())) - .collect(Collectors.toList()); - dimensionYamlMap - .get(modelResp.getBizName()) + List dimensionRespList = dimensionResps.stream() + .filter(d -> d.getModelBizName().equalsIgnoreCase(modelResp.getBizName())) + .collect(Collectors.toList()); + dimensionYamlMap.get(modelResp.getBizName()) .addAll(DimensionYamlManager.convert2DimensionYaml(dimensionRespList)); } List metricResps = new ArrayList<>(semanticSchemaResp.getMetrics()); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TagMetaServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TagMetaServiceImpl.java index 5c2beb538..a59f93a08 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TagMetaServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TagMetaServiceImpl.java @@ -58,13 +58,9 @@ public class TagMetaServiceImpl implements TagMetaService { private final TagObjectService tagObjectService; private final DomainService domainService; - public TagMetaServiceImpl( - TagRepository tagRepository, - ModelService modelService, - CollectService collectService, - @Lazy DimensionService dimensionService, - @Lazy MetricService metricService, - TagObjectService tagObjectService, + public TagMetaServiceImpl(TagRepository tagRepository, ModelService modelService, + CollectService collectService, @Lazy DimensionService dimensionService, + @Lazy MetricService metricService, TagObjectService tagObjectService, DomainService domainService) { this.tagRepository = tagRepository; this.modelService = modelService; @@ -162,11 +158,8 @@ public class TagMetaServiceImpl implements TagMetaService { if (Objects.nonNull(tagMarketPageReq.getTagObjectId())) { modelRespList = modelRespList.stream() - .filter( - modelResp -> - tagMarketPageReq - .getTagObjectId() - .equals(modelResp.getTagObjectId())) + .filter(modelResp -> tagMarketPageReq.getTagObjectId() + .equals(modelResp.getTagObjectId())) .collect(Collectors.toList()); } if (CollectionUtils.isEmpty(modelRespList)) { @@ -179,15 +172,9 @@ public class TagMetaServiceImpl implements TagMetaService { BeanUtils.copyProperties(tagMarketPageReq, tagFilter); List collectList = collectService.getCollectionList(user.getName()); if (tagMarketPageReq.isHasCollect()) { - List collectIds = - collectList.stream() - .filter( - collectDO -> - SchemaElementType.TAG - .name() - .equalsIgnoreCase(collectDO.getType())) - .map(CollectDO::getCollectId) - .collect(Collectors.toList()); + List collectIds = collectList.stream().filter( + collectDO -> SchemaElementType.TAG.name().equalsIgnoreCase(collectDO.getType())) + .map(CollectDO::getCollectId).collect(Collectors.toList()); if (CollectionUtils.isEmpty(collectIds)) { tagFilter.setIds(Lists.newArrayList(-1L)); } else { @@ -217,22 +204,14 @@ public class TagMetaServiceImpl implements TagMetaService { if (CollectionUtils.isEmpty(tagObjects)) { return; } - Map tagObjectMap = - tagObjects.stream() - .collect( - Collectors.toMap( - TagObjectResp::getId, - tagObject -> tagObject, - (v1, v2) -> v2)); + Map tagObjectMap = tagObjects.stream().collect( + Collectors.toMap(TagObjectResp::getId, tagObject -> tagObject, (v1, v2) -> v2)); if (CollectionUtils.isNotEmpty(tagRespList)) { - tagRespList.stream() - .forEach( - tagResp -> { - if (tagObjectMap.containsKey(tagResp.getTagObjectId())) { - tagResp.setTagObjectName( - tagObjectMap.get(tagResp.getTagObjectId()).getName()); - } - }); + tagRespList.stream().forEach(tagResp -> { + if (tagObjectMap.containsKey(tagResp.getTagObjectId())) { + tagResp.setTagObjectName(tagObjectMap.get(tagResp.getTagObjectId()).getName()); + } + }); } } @@ -246,20 +225,14 @@ public class TagMetaServiceImpl implements TagMetaService { } private void fillDomainInfo(List tagRespList) { - Map domainMap = - domainService.getDomainList().stream() - .collect( - Collectors.toMap( - DomainResp::getId, domain -> domain, (v1, v2) -> v2)); + Map domainMap = domainService.getDomainList().stream() + .collect(Collectors.toMap(DomainResp::getId, domain -> domain, (v1, v2) -> v2)); if (CollectionUtils.isNotEmpty(tagRespList) && Objects.nonNull(domainMap)) { - tagRespList.stream() - .forEach( - tagResp -> { - if (domainMap.containsKey(tagResp.getDomainId())) { - tagResp.setDomainName( - domainMap.get(tagResp.getDomainId()).getName()); - } - }); + tagRespList.stream().forEach(tagResp -> { + if (domainMap.containsKey(tagResp.getDomainId())) { + tagResp.setDomainName(domainMap.get(tagResp.getDomainId()).getName()); + } + }); } } @@ -321,31 +294,21 @@ public class TagMetaServiceImpl implements TagMetaService { tagRespList.stream().map(TagResp::getModelId).collect(Collectors.toList()); ModelFilter modelFilter = new ModelFilter(false, modelIds); Map modelIdAndRespMap = modelService.getModelMap(modelFilter); - tagRespList.stream() - .forEach( - tagResp -> { - if (Objects.nonNull(modelIdAndRespMap) - && modelIdAndRespMap.containsKey(tagResp.getModelId())) { - tagResp.setModelName( - modelIdAndRespMap.get(tagResp.getModelId()).getName()); - tagResp.setDomainId( - modelIdAndRespMap.get(tagResp.getModelId()).getDomainId()); - tagResp.setTagObjectId( - modelIdAndRespMap - .get(tagResp.getModelId()) - .getTagObjectId()); - } - }); + tagRespList.stream().forEach(tagResp -> { + if (Objects.nonNull(modelIdAndRespMap) + && modelIdAndRespMap.containsKey(tagResp.getModelId())) { + tagResp.setModelName(modelIdAndRespMap.get(tagResp.getModelId()).getName()); + tagResp.setDomainId(modelIdAndRespMap.get(tagResp.getModelId()).getDomainId()); + tagResp.setTagObjectId( + modelIdAndRespMap.get(tagResp.getModelId()).getTagObjectId()); + } + }); } private TagResp fillCollectAndAdminInfo(TagResp tagResp, User user) { - List collectIds = - collectService.getCollectionList(user.getName()).stream() - .filter( - collectDO -> - TypeEnums.TAG.name().equalsIgnoreCase(collectDO.getType())) - .map(CollectDO::getCollectId) - .collect(Collectors.toList()); + List collectIds = collectService.getCollectionList(user.getName()).stream() + .filter(collectDO -> TypeEnums.TAG.name().equalsIgnoreCase(collectDO.getType())) + .map(CollectDO::getCollectId).collect(Collectors.toList()); if (CollectionUtils.isNotEmpty(collectIds) && collectIds.contains(tagResp.getId())) { tagResp.setIsCollect(true); } else { @@ -357,24 +320,17 @@ public class TagMetaServiceImpl implements TagMetaService { } private TagResp fillCollectAndAdminInfo(List tagRespList, User user) { - List collectIds = - collectService.getCollectionList(user.getName()).stream() - .filter( - collectDO -> - TypeEnums.TAG.name().equalsIgnoreCase(collectDO.getType())) - .map(CollectDO::getCollectId) - .collect(Collectors.toList()); + List collectIds = collectService.getCollectionList(user.getName()).stream() + .filter(collectDO -> TypeEnums.TAG.name().equalsIgnoreCase(collectDO.getType())) + .map(CollectDO::getCollectId).collect(Collectors.toList()); - tagRespList.stream() - .forEach( - tagResp -> { - if (CollectionUtils.isNotEmpty(collectIds) - && collectIds.contains(tagResp.getId())) { - tagResp.setIsCollect(true); - } else { - tagResp.setIsCollect(false); - } - }); + tagRespList.stream().forEach(tagResp -> { + if (CollectionUtils.isNotEmpty(collectIds) && collectIds.contains(tagResp.getId())) { + tagResp.setIsCollect(true); + } else { + tagResp.setIsCollect(false); + } + }); fillAdminRes(tagRespList, user); return tagRespList.get(0); @@ -418,21 +374,17 @@ public class TagMetaServiceImpl implements TagMetaService { ModelResp model = modelService.getModel(dimension.getModelId()); if (Objects.isNull(model.getTagObjectId())) { throw new RuntimeException( - String.format( - "this dimension:%s is not supported to create tag," - + " no related tag object", - tagReq.getItemId())); + String.format("this dimension:%s is not supported to create tag," + + " no related tag object", tagReq.getItemId())); } } if (TagDefineType.METRIC.equals(tagReq.getTagDefineType())) { MetricResp metric = metricService.getMetric(tagReq.getItemId()); ModelResp model = modelService.getModel(metric.getModelId()); if (Objects.isNull(model.getTagObjectId())) { - throw new RuntimeException( - String.format( - "this metric:%s is not supported to create tag," - + " no related tag object", - tagReq.getItemId())); + throw new RuntimeException(String.format( + "this metric:%s is not supported to create tag," + " no related tag object", + tagReq.getItemId())); } } } @@ -451,15 +403,11 @@ public class TagMetaServiceImpl implements TagMetaService { tagFilter.setItemIds(itemIds); Set dimensionItemSet = getTagDOList(tagFilter).stream().map(TagDO::getItemId).collect(Collectors.toSet()); - return itemIds.stream() - .map( - entry -> { - TagItem tagItem = new TagItem(); - tagItem.setIsTag( - Boolean.compare(dimensionItemSet.contains(entry), false)); - tagItem.setItemId(entry); - return tagItem; - }) - .collect(Collectors.toList()); + return itemIds.stream().map(entry -> { + TagItem tagItem = new TagItem(); + tagItem.setIsTag(Boolean.compare(dimensionItemSet.contains(entry), false)); + tagItem.setItemId(entry); + return tagItem; + }).collect(Collectors.toList()); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TagObjectServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TagObjectServiceImpl.java index 96aa360e8..6e2d5b346 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TagObjectServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TagObjectServiceImpl.java @@ -34,9 +34,7 @@ public class TagObjectServiceImpl implements TagObjectService { private final ModelService modelService; private final TagMetaService tagMetaService; - public TagObjectServiceImpl( - TagObjectRepository tagObjectRepository, - ModelService modelService, + public TagObjectServiceImpl(TagObjectRepository tagObjectRepository, ModelService modelService, @Lazy TagMetaService tagMetaService) { this.tagObjectRepository = tagObjectRepository; this.modelService = modelService; @@ -66,14 +64,9 @@ public class TagObjectServiceImpl implements TagObjectService { if (CollectionUtils.isEmpty(tagObjectRespList)) { return; } - tagObjectRespList = - tagObjectRespList.stream() - .filter( - tagObjectResp -> - StatusEnum.ONLINE - .getCode() - .equals(tagObjectResp.getStatus())) - .collect(Collectors.toList()); + tagObjectRespList = tagObjectRespList.stream().filter( + tagObjectResp -> StatusEnum.ONLINE.getCode().equals(tagObjectResp.getStatus())) + .collect(Collectors.toList()); for (TagObjectResp tagObject : tagObjectRespList) { if (tagObject.getBizName().equalsIgnoreCase(tagObjectReq.getBizName())) { throw new Exception( @@ -128,9 +121,8 @@ public class TagObjectServiceImpl implements TagObjectService { if (!CollectionUtils.isEmpty(allModelByDomainIds)) { List modelIds = allModelByDomainIds.stream().map(ModelResp::getId).collect(Collectors.toList()); - throw new Exception( - "delete operation is not supported at the moment. related modelIds:" - + modelIds); + throw new Exception("delete operation is not supported at the moment. related modelIds:" + + modelIds); } TagFilterPageReq tagMarketPageReq = new TagFilterPageReq(); tagMarketPageReq.setTagObjectId(tagObjectDO.getId()); @@ -173,9 +165,8 @@ public class TagObjectServiceImpl implements TagObjectService { List tagObjectDOList = tagObjectRepository.query(filter); List tagObjectRespList = TagObjectConverter.convert2RespList(tagObjectDOList); - Map map = - tagObjectRespList.stream() - .collect(Collectors.toMap(TagObjectResp::getId, a -> a, (k1, k2) -> k1)); + Map map = tagObjectRespList.stream() + .collect(Collectors.toMap(TagObjectResp::getId, a -> a, (k1, k2) -> k1)); return map; } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TagQueryServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TagQueryServiceImpl.java index f79a15223..394137f13 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TagQueryServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TagQueryServiceImpl.java @@ -52,11 +52,8 @@ public class TagQueryServiceImpl implements TagQueryService { private final ModelService modelService; private final SqlGenerateUtils sqlGenerateUtils; - public TagQueryServiceImpl( - TagMetaService tagMetaService, - SemanticLayerService queryService, - ModelService modelService, - SqlGenerateUtils sqlGenerateUtils) { + public TagQueryServiceImpl(TagMetaService tagMetaService, SemanticLayerService queryService, + ModelService modelService, SqlGenerateUtils sqlGenerateUtils) { this.tagMetaService = tagMetaService; this.queryService = queryService; this.modelService = modelService; @@ -88,9 +85,8 @@ public class TagQueryServiceImpl implements TagQueryService { private void checkTag(TagResp tag) throws Exception { if (Objects.nonNull(tag) && TagDefineType.METRIC.name().equalsIgnoreCase(tag.getTagDefineType())) { - throw new Exception( - "do not support value distribution query for tag (from metric): " - + tag.getBizName()); + throw new Exception("do not support value distribution query for tag (from metric): " + + tag.getBizName()); } } @@ -121,16 +117,12 @@ public class TagQueryServiceImpl implements TagQueryService { return LocalDate.now().plusDays(-dayBefore).format(formatter); } - private String queryTagDateFromDbBySql( - Dim dim, TagResp tag, ItemValueReq itemValueReq, User user) { + private String queryTagDateFromDbBySql(Dim dim, TagResp tag, ItemValueReq itemValueReq, + User user) { String sqlPattern = "select max(%s) as %s from tbl where %s is not null"; - String sql = - String.format( - sqlPattern, - TimeDimensionEnum.DAY.getName(), - maxDateAlias, - tag.getBizName()); + String sql = String.format(sqlPattern, TimeDimensionEnum.DAY.getName(), maxDateAlias, + tag.getBizName()); // 添加时间过滤信息 log.info("[queryTagDateFromDbBySql] calculate the maximum time start"); @@ -143,22 +135,13 @@ public class TagQueryServiceImpl implements TagQueryService { if (StringUtils.isEmpty(dateFormat)) { dateFormat = itemValueDateFormat; } - String start = - LocalDate.now() - .minusDays(itemValueReq.getDateConf().getUnit()) - .format(DateTimeFormatter.ofPattern(dateFormat)); - String end = - LocalDate.now() - .minusDays(0) - .format(DateTimeFormatter.ofPattern(dateFormat)); - sql = - sql - + String.format( - " and ( %s > '%s' and %s <= '%s' )", - TimeDimensionEnum.DAY.getName(), - start, - TimeDimensionEnum.DAY.getName(), - end); + String start = LocalDate.now().minusDays(itemValueReq.getDateConf().getUnit()) + .format(DateTimeFormatter.ofPattern(dateFormat)); + String end = LocalDate.now().minusDays(0) + .format(DateTimeFormatter.ofPattern(dateFormat)); + sql = sql + String.format(" and ( %s > '%s' and %s <= '%s' )", + TimeDimensionEnum.DAY.getName(), start, TimeDimensionEnum.DAY.getName(), + end); } } } @@ -219,42 +202,27 @@ public class TagQueryServiceImpl implements TagQueryService { return " and " + dateWhereClause; } - private void fillTagValueInfo( - ItemValueResp itemValueResp, SemanticQueryResp semanticQueryResp, Long totalCount) { + private void fillTagValueInfo(ItemValueResp itemValueResp, SemanticQueryResp semanticQueryResp, + Long totalCount) { List valueDistributionList = new ArrayList<>(); List> resultList = semanticQueryResp.getResultList(); if (!CollectionUtils.isEmpty(resultList)) { - resultList.stream() - .forEach( - line -> { - Object tagValue = line.get(itemValueResp.getBizName()); - Long tagValueCount = - Long.parseLong(line.get(tagValueAlias).toString()); - valueDistributionList.add( - ValueDistribution.builder() - .totalCount(totalCount) - .valueMap(tagValue) - .valueCount(tagValueCount) - .ratio(1.0 * tagValueCount / totalCount) - .build()); - }); + resultList.stream().forEach(line -> { + Object tagValue = line.get(itemValueResp.getBizName()); + Long tagValueCount = Long.parseLong(line.get(tagValueAlias).toString()); + valueDistributionList.add(ValueDistribution.builder().totalCount(totalCount) + .valueMap(tagValue).valueCount(tagValueCount) + .ratio(1.0 * tagValueCount / totalCount).build()); + }); } itemValueResp.setValueDistributionList(valueDistributionList); } private QuerySqlReq generateReq(TagResp tag, ItemValueReq itemValueReq) { - String sqlPattern = - "select %s, count(1) as %s from tbl where %s is not null %s " - + "group by %s order by %s desc"; - String sql = - String.format( - sqlPattern, - tag.getBizName(), - tagValueAlias, - tag.getBizName(), - getDateFilter(itemValueReq), - tag.getBizName(), - tag.getBizName()); + String sqlPattern = "select %s, count(1) as %s from tbl where %s is not null %s " + + "group by %s order by %s desc"; + String sql = String.format(sqlPattern, tag.getBizName(), tagValueAlias, tag.getBizName(), + getDateFilter(itemValueReq), tag.getBizName(), tag.getBizName()); Set modelIds = new HashSet<>(); modelIds.add(tag.getModelId()); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TermServiceImpl.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TermServiceImpl.java index b116d0f45..6a3a8873f 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TermServiceImpl.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/service/impl/TermServiceImpl.java @@ -56,15 +56,8 @@ public class TermServiceImpl extends ServiceImpl implements QueryWrapper queryWrapper = new QueryWrapper<>(); queryWrapper.lambda().eq(TermDO::getDomainId, domainId); if (StringUtils.isNotBlank(queryKey)) { - queryWrapper - .lambda() - .and( - i -> - i.like(TermDO::getName, queryKey) - .or() - .like(TermDO::getDescription, queryKey) - .or() - .like(TermDO::getAlias, queryKey)); + queryWrapper.lambda().and(i -> i.like(TermDO::getName, queryKey).or() + .like(TermDO::getDescription, queryKey).or().like(TermDO::getAlias, queryKey)); } List termDOS = list(queryWrapper); return termDOS.stream().map(this::convert).collect(Collectors.toList()); @@ -78,8 +71,7 @@ public class TermServiceImpl extends ServiceImpl implements QueryWrapper queryWrapper = new QueryWrapper<>(); queryWrapper.lambda().in(TermDO::getDomainId, domainIds); List list = list(queryWrapper); - return list.stream() - .map(this::convert) + return list.stream().map(this::convert) .collect(Collectors.groupingBy(TermResp::getDomainId)); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/task/DictionaryReloadTask.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/task/DictionaryReloadTask.java index 397b41854..75578ad1a 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/task/DictionaryReloadTask.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/task/DictionaryReloadTask.java @@ -13,7 +13,8 @@ import org.springframework.stereotype.Component; @Order(2) public class DictionaryReloadTask implements CommandLineRunner { - @Autowired private DictWordService dictWordService; + @Autowired + private DictWordService dictWordService; @Override public void run(String... args) { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/task/MetaEmbeddingTask.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/task/MetaEmbeddingTask.java index ab334dfaa..52610b9ce 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/task/MetaEmbeddingTask.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/task/MetaEmbeddingTask.java @@ -25,13 +25,17 @@ import java.util.List; @Order(2) public class MetaEmbeddingTask implements CommandLineRunner { - @Autowired private EmbeddingService embeddingService; + @Autowired + private EmbeddingService embeddingService; - @Autowired private EmbeddingConfig embeddingConfig; + @Autowired + private EmbeddingConfig embeddingConfig; - @Autowired private MetricService metricService; + @Autowired + private MetricService metricService; - @Autowired private DimensionService dimensionService; + @Autowired + private DimensionService dimensionService; @PreDestroy public void onShutdown() { @@ -62,13 +66,11 @@ public class MetaEmbeddingTask implements CommandLineRunner { try { List metricDataItems = metricService.getDataEvent().getDataItems(); - embeddingService.addQuery( - embeddingConfig.getMetaCollectionName(), + embeddingService.addQuery(embeddingConfig.getMetaCollectionName(), TextSegmentConvert.convertToEmbedding(metricDataItems)); List dimensionDataItems = dimensionService.getDataEvent().getDataItems(); - embeddingService.addQuery( - embeddingConfig.getMetaCollectionName(), + embeddingService.addQuery(embeddingConfig.getMetaCollectionName(), TextSegmentConvert.convertToEmbedding(dimensionDataItems)); } catch (Exception e) { log.error("Failed to reload meta embedding.", e); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/AliasGenerateHelper.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/AliasGenerateHelper.java index 429c1bf8b..bff38ca10 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/AliasGenerateHelper.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/AliasGenerateHelper.java @@ -23,38 +23,32 @@ public class AliasGenerateHelper { private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); - private static final String NAME_ALIAS_INSTRUCTION = - "" - + "\n#Role: You are a professional data analyst specializing in metrics and dimensions." - + "\n#Task: You will be provided with metadata about a metric or dimension, please help " - + "generate a few aliases in the same language as its `fieldName`." - + "\n#Rules:" - + "1. Please do not generate aliases like xxx1, xxx2, xxx3." - + "2. Please do not generate aliases that are the same as the original names of metrics/dimensions." - + "3. Please pay attention to the quality of the generated aliases and " - + "avoid creating aliases that look like test data." - + "4. Please output as a json string array." - + "\n#Metadata: {'table':'{{table}}', 'name':'{{name}}', 'type':'{{type}}', " - + "'field':'field', 'description':'{{desc}}'}" - + "\n#Output:"; + private static final String NAME_ALIAS_INSTRUCTION = "" + + "\n#Role: You are a professional data analyst specializing in metrics and dimensions." + + "\n#Task: You will be provided with metadata about a metric or dimension, please help " + + "generate a few aliases in the same language as its `fieldName`." + "\n#Rules:" + + "1. Please do not generate aliases like xxx1, xxx2, xxx3." + + "2. Please do not generate aliases that are the same as the original names of metrics/dimensions." + + "3. Please pay attention to the quality of the generated aliases and " + + "avoid creating aliases that look like test data." + + "4. Please output as a json string array." + + "\n#Metadata: {'table':'{{table}}', 'name':'{{name}}', 'type':'{{type}}', " + + "'field':'field', 'description':'{{desc}}'}" + "\n#Output:"; private static final String VALUE_ALIAS_INSTRUCTION = - "" - + "\n#Role: You are a professional data analyst." + "" + "\n#Role: You are a professional data analyst." + "\n#Task: You will be provided with a json array of dimension values," - + "please help generate a few aliases for each value." - + "\n#Rule:" + + "please help generate a few aliases for each value." + "\n#Rule:" + "1. ALWAYS output json array for each value." + "2. The aliases should be in the same language as its original value." - + "\n#Exemplar:" - + "Values: [\\\"qq_music\\\",\\\"kugou_music\\\"], " + + "\n#Exemplar:" + "Values: [\\\"qq_music\\\",\\\"kugou_music\\\"], " + "Output: {\\\"tran\\\":[\\\"qq音乐\\\",\\\"酷狗音乐\\\"]," + " \\\"alias\\\":{\\\"qq_music\\\":[\\\"q音\\\",\\\"qq音乐\\\"]," + " \\\"kugou_music\\\":[\\\"kugou\\\",\\\"酷狗\\\"]}}" + "\nValues: {{values}}, Output:"; - public String generateAlias( - String mockType, String name, String bizName, String table, String desc) { + public String generateAlias(String mockType, String name, String bizName, String table, + String desc) { Map variable = new HashMap<>(); variable.put("table", table); variable.put("name", name); @@ -88,8 +82,8 @@ public class AliasGenerateHelper { return response.content().text(); } - private static String extractString( - String targetString, String left, String right, Boolean exclusionFlag) { + private static String extractString(String targetString, String left, String right, + Boolean exclusionFlag) { if (targetString == null || left == null || right == null || exclusionFlag == null) { return targetString; } @@ -139,19 +133,18 @@ public class AliasGenerateHelper { } } BoundaryPattern[] patterns = { - // 不做任何匹配 - new BoundaryPattern(null, null, null), - // ```{"name":"Alice","age":25,"city":"NewYork"}``` - new BoundaryPattern("```", "```", true), - // ```json {"name":"Alice","age":25,"city":"NewYork"}``` - new BoundaryPattern("```json", "```", true), - // ```JSON {"name":"Alice","age":25,"city":"NewYork"}``` - new BoundaryPattern("```JSON", "```", true), - // {"name":"Alice","age":25,"city":"NewYork"} - new BoundaryPattern("{", "}", false), - // ["Alice", "Bob"] - new BoundaryPattern("[", "]", false) - }; + // 不做任何匹配 + new BoundaryPattern(null, null, null), + // ```{"name":"Alice","age":25,"city":"NewYork"}``` + new BoundaryPattern("```", "```", true), + // ```json {"name":"Alice","age":25,"city":"NewYork"}``` + new BoundaryPattern("```json", "```", true), + // ```JSON {"name":"Alice","age":25,"city":"NewYork"}``` + new BoundaryPattern("```JSON", "```", true), + // {"name":"Alice","age":25,"city":"NewYork"} + new BoundaryPattern("{", "}", false), + // ["Alice", "Bob"] + new BoundaryPattern("[", "]", false)}; for (BoundaryPattern pattern : patterns) { String extracted = extractString(aiMessage, pattern.left, pattern.right, pattern.exclusionFlag); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java index 37008e74d..8568fee8e 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ChatWorkflowEngine.java @@ -62,10 +62,8 @@ public class ChatWorkflowEngine { parseResult.setErrorMsg("No semantic queries can be parsed out."); queryCtx.setChatWorkflowState(ChatWorkflowState.FINISHED); } else { - List parseInfos = - queryCtx.getCandidateQueries().stream() - .map(SemanticQuery::getParseInfo) - .collect(Collectors.toList()); + List parseInfos = queryCtx.getCandidateQueries().stream() + .map(SemanticQuery::getParseInfo).collect(Collectors.toList()); parseResult.setSelectedParses(parseInfos); queryCtx.setChatWorkflowState(ChatWorkflowState.CORRECTING); } @@ -101,14 +99,11 @@ public class ChatWorkflowEngine { } private void performParsing(ChatQueryContext queryCtx) { - semanticParsers.forEach( - parser -> { - parser.parse(queryCtx); - log.debug( - "{} result:{}", - parser.getClass().getSimpleName(), - JsonUtil.toString(queryCtx)); - }); + semanticParsers.forEach(parser -> { + parser.parse(queryCtx); + log.debug("{} result:{}", parser.getClass().getSimpleName(), + JsonUtil.toString(queryCtx)); + }); } private void performCorrecting(ChatQueryContext queryCtx) { @@ -126,45 +121,38 @@ public class ChatWorkflowEngine { } private void performProcessing(ChatQueryContext queryCtx, ParseResp parseResult) { - resultProcessors.forEach( - processor -> { - processor.process(parseResult, queryCtx); - }); + resultProcessors.forEach(processor -> { + processor.process(parseResult, queryCtx); + }); } private void performTranslating(ChatQueryContext chatQueryContext) { - List semanticParseInfos = - chatQueryContext.getCandidateQueries().stream() - .map(SemanticQuery::getParseInfo) - .collect(Collectors.toList()); + List semanticParseInfos = chatQueryContext.getCandidateQueries().stream() + .map(SemanticQuery::getParseInfo).collect(Collectors.toList()); - semanticParseInfos.forEach( - parseInfo -> { - try { - SemanticQuery semanticQuery = - QueryManager.createQuery(parseInfo.getQueryMode()); - if (Objects.isNull(semanticQuery)) { - return; - } - semanticQuery.setParseInfo(parseInfo); - SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); - SemanticLayerService queryService = - ContextUtils.getBean(SemanticLayerService.class); - SemanticTranslateResp explain = - queryService.translate( - semanticQueryReq, chatQueryContext.getUser()); - parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL()); + semanticParseInfos.forEach(parseInfo -> { + try { + SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode()); + if (Objects.isNull(semanticQuery)) { + return; + } + semanticQuery.setParseInfo(parseInfo); + SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); + SemanticLayerService queryService = + ContextUtils.getBean(SemanticLayerService.class); + SemanticTranslateResp explain = + queryService.translate(semanticQueryReq, chatQueryContext.getUser()); + parseInfo.getSqlInfo().setQuerySQL(explain.getQuerySQL()); - keyPipelineLog.info( - "SqlInfoProcessor results:\n" - + "Parsed S2SQL: {}\nCorrected S2SQL: {}\nQuery SQL: {}", - StringUtils.normalizeSpace(parseInfo.getSqlInfo().getParsedS2SQL()), - StringUtils.normalizeSpace( - parseInfo.getSqlInfo().getCorrectedS2SQL()), - StringUtils.normalizeSpace(parseInfo.getSqlInfo().getQuerySQL())); - } catch (Exception e) { - log.warn("get sql info failed:{}", parseInfo, e); - } - }); + keyPipelineLog.info( + "SqlInfoProcessor results:\n" + + "Parsed S2SQL: {}\nCorrected S2SQL: {}\nQuery SQL: {}", + StringUtils.normalizeSpace(parseInfo.getSqlInfo().getParsedS2SQL()), + StringUtils.normalizeSpace(parseInfo.getSqlInfo().getCorrectedS2SQL()), + StringUtils.normalizeSpace(parseInfo.getSqlInfo().getQuerySQL())); + } catch (Exception e) { + log.warn("get sql info failed:{}", parseInfo, e); + } + }); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ClassConverter.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ClassConverter.java index 482f51a26..af782fa76 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ClassConverter.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ClassConverter.java @@ -29,9 +29,7 @@ public class ClassConverter { private final DomainService domainService; private final TagObjectService tagObjectService; - public ClassConverter( - ClassRepository classRepository, - DomainService domainService, + public ClassConverter(ClassRepository classRepository, DomainService domainService, TagObjectService tagObjectService) { this.classRepository = classRepository; this.domainService = domainService; @@ -57,16 +55,13 @@ public class ClassConverter { return convert2RespInternal(classDO, idAndDomain, classFullPathMap); } - private ClassResp convert2RespInternal( - ClassDO classDO, - Map idAndDomain, + private ClassResp convert2RespInternal(ClassDO classDO, Map idAndDomain, Map classFullPathMap) { ClassResp classResp = new ClassResp(); BeanUtils.copyProperties(classDO, classResp); Long domainId = classResp.getDomainId(); - if (Objects.nonNull(idAndDomain) - && idAndDomain.containsKey(domainId) + if (Objects.nonNull(idAndDomain) && idAndDomain.containsKey(domainId) && Objects.nonNull(idAndDomain.get(domainId))) { classResp.setDomainName(idAndDomain.get(domainId).getName()); } @@ -94,9 +89,8 @@ public class ClassConverter { public Map getClassFullPathMap() { Map classFullPathMap = new HashMap<>(); List classDOList = classRepository.getAllClassDOList(); - Map classDOMap = - classDOList.stream() - .collect(Collectors.toMap(ClassDO::getId, a -> a, (k1, k2) -> k1)); + Map classDOMap = classDOList.stream() + .collect(Collectors.toMap(ClassDO::getId, a -> a, (k1, k2) -> k1)); for (ClassDO classDO : classDOList) { final Long domainId = classDO.getId(); StringBuilder fullPath = new StringBuilder(classDO.getBizName() + "/"); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ComponentFactory.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ComponentFactory.java index 2322b2090..db3223699 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ComponentFactory.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ComponentFactory.java @@ -27,8 +27,7 @@ public class ComponentFactory { } public static List getSchemaMappers() { - return CollectionUtils.isEmpty(schemaMappers) - ? init(SchemaMapper.class, schemaMappers) + return CollectionUtils.isEmpty(schemaMappers) ? init(SchemaMapper.class, schemaMappers) : schemaMappers; } @@ -49,15 +48,13 @@ public class ComponentFactory { } private static List init(Class factoryType, List list) { - list.addAll( - SpringFactoriesLoader.loadFactories( - factoryType, Thread.currentThread().getContextClassLoader())); + list.addAll(SpringFactoriesLoader.loadFactories(factoryType, + Thread.currentThread().getContextClassLoader())); return list; } private static T init(Class factoryType) { - return SpringFactoriesLoader.loadFactories( - factoryType, Thread.currentThread().getContextClassLoader()) - .get(0); + return SpringFactoriesLoader + .loadFactories(factoryType, Thread.currentThread().getContextClassLoader()).get(0); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java index c338d1927..94b71979a 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DataSetSchemaBuilder.java @@ -26,15 +26,9 @@ public class DataSetSchemaBuilder { public static DataSetSchema build(DataSetSchemaResp resp) { DataSetSchema dataSetSchema = new DataSetSchema(); dataSetSchema.setQueryConfig(resp.getQueryConfig()); - SchemaElement dataSet = - SchemaElement.builder() - .dataSetId(resp.getId()) - .dataSetName(resp.getName()) - .id(resp.getId()) - .name(resp.getName()) - .bizName(resp.getBizName()) - .type(SchemaElementType.DATASET) - .build(); + SchemaElement dataSet = SchemaElement.builder().dataSetId(resp.getId()) + .dataSetName(resp.getName()).id(resp.getId()).name(resp.getName()) + .bizName(resp.getBizName()).type(SchemaElementType.DATASET).build(); dataSetSchema.setDataSet(dataSet); dataSetSchema.setDatabaseType(resp.getDatabaseType()); @@ -68,21 +62,12 @@ public class DataSetSchemaBuilder { for (MetricSchemaResp metric : resp.getMetrics()) { List alias = SchemaItem.getAliasList(metric.getAlias()); if (metric.getIsTag() == 1) { - SchemaElement tagToAdd = - SchemaElement.builder() - .dataSetId(resp.getId()) - .dataSetName(resp.getName()) - .model(metric.getModelId()) - .id(metric.getId()) - .name(metric.getName()) - .bizName(metric.getBizName()) - .type(SchemaElementType.TAG) - .useCnt(metric.getUseCnt()) - .alias(alias) - .defaultAgg(metric.getDefaultAgg()) - .isTag(metric.getIsTag()) - .description(metric.getDescription()) - .build(); + SchemaElement tagToAdd = SchemaElement.builder().dataSetId(resp.getId()) + .dataSetName(resp.getName()).model(metric.getModelId()).id(metric.getId()) + .name(metric.getName()).bizName(metric.getBizName()) + .type(SchemaElementType.TAG).useCnt(metric.getUseCnt()).alias(alias) + .defaultAgg(metric.getDefaultAgg()).isTag(metric.getIsTag()) + .description(metric.getDescription()).build(); tags.add(tagToAdd); } } @@ -103,21 +88,11 @@ public class DataSetSchemaBuilder { } } if (dim.getIsTag() == 1) { - SchemaElement tagToAdd = - SchemaElement.builder() - .dataSetId(resp.getId()) - .dataSetName(resp.getName()) - .model(dim.getModelId()) - .id(dim.getId()) - .name(dim.getName()) - .bizName(dim.getBizName()) - .type(SchemaElementType.TAG) - .useCnt(dim.getUseCnt()) - .alias(alias) - .schemaValueMaps(schemaValueMaps) - .isTag(dim.getIsTag()) - .description(dim.getDescription()) - .build(); + SchemaElement tagToAdd = SchemaElement.builder().dataSetId(resp.getId()) + .dataSetName(resp.getName()).model(dim.getModelId()).id(dim.getId()) + .name(dim.getName()).bizName(dim.getBizName()).type(SchemaElementType.TAG) + .useCnt(dim.getUseCnt()).alias(alias).schemaValueMaps(schemaValueMaps) + .isTag(dim.getIsTag()).description(dim.getDescription()).build(); tags.add(tagToAdd); } } @@ -129,15 +104,9 @@ public class DataSetSchemaBuilder { if (Objects.isNull(dim)) { return null; } - return SchemaElement.builder() - .dataSetId(resp.getId()) - .model(dim.getModelId()) - .id(dim.getId()) - .name(dim.getName()) - .bizName(dim.getBizName()) - .type(SchemaElementType.ENTITY) - .useCnt(dim.getUseCnt()) - .alias(dim.getEntityAlias()) + return SchemaElement.builder().dataSetId(resp.getId()).model(dim.getModelId()) + .id(dim.getId()).name(dim.getName()).bizName(dim.getBizName()) + .type(SchemaElementType.ENTITY).useCnt(dim.getUseCnt()).alias(dim.getEntityAlias()) .build(); } @@ -154,21 +123,11 @@ public class DataSetSchemaBuilder { schemaValueMaps.add(schemaValueMap); } } - SchemaElement dimToAdd = - SchemaElement.builder() - .dataSetId(resp.getId()) - .dataSetName(resp.getName()) - .model(dim.getModelId()) - .id(dim.getId()) - .name(dim.getName()) - .bizName(dim.getBizName()) - .useCnt(dim.getUseCnt()) - .alias(alias) - .schemaValueMaps(schemaValueMaps) - .isTag(dim.getIsTag()) - .description(dim.getDescription()) - .type(SchemaElementType.DIMENSION) - .build(); + SchemaElement dimToAdd = SchemaElement.builder().dataSetId(resp.getId()) + .dataSetName(resp.getName()).model(dim.getModelId()).id(dim.getId()) + .name(dim.getName()).bizName(dim.getBizName()).useCnt(dim.getUseCnt()) + .alias(alias).schemaValueMaps(schemaValueMaps).isTag(dim.getIsTag()) + .description(dim.getDescription()).type(SchemaElementType.DIMENSION).build(); dimToAdd.getExtInfo().put(DimensionConstants.DIMENSION_TYPE, dim.getType()); if (dim.isTimeDimension()) { @@ -196,22 +155,12 @@ public class DataSetSchemaBuilder { } } } - SchemaElement dimValueToAdd = - SchemaElement.builder() - .dataSetId(resp.getId()) - .dataSetName(resp.getName()) - .model(dim.getModelId()) - .id(dim.getId()) - .name(dim.getName()) - .bizName(dim.getBizName()) - .type(SchemaElementType.VALUE) - .useCnt(dim.getUseCnt()) - .alias( - new ArrayList<>( - Arrays.asList(dimValueAlias.toArray(new String[0])))) - .isTag(dim.getIsTag()) - .description(dim.getDescription()) - .build(); + SchemaElement dimValueToAdd = SchemaElement.builder().dataSetId(resp.getId()) + .dataSetName(resp.getName()).model(dim.getModelId()).id(dim.getId()) + .name(dim.getName()).bizName(dim.getBizName()).type(SchemaElementType.VALUE) + .useCnt(dim.getUseCnt()) + .alias(new ArrayList<>(Arrays.asList(dimValueAlias.toArray(new String[0])))) + .isTag(dim.getIsTag()).description(dim.getDescription()).build(); dimensionValues.add(dimValueToAdd); } return dimensionValues; @@ -223,23 +172,13 @@ public class DataSetSchemaBuilder { List alias = SchemaItem.getAliasList(metric.getAlias()); - SchemaElement metricToAdd = - SchemaElement.builder() - .dataSetId(resp.getId()) - .dataSetName(resp.getName()) - .model(metric.getModelId()) - .id(metric.getId()) - .name(metric.getName()) - .bizName(metric.getBizName()) - .type(SchemaElementType.METRIC) - .useCnt(metric.getUseCnt()) - .alias(alias) - .relatedSchemaElements(getRelateSchemaElement(metric)) - .defaultAgg(metric.getDefaultAgg()) - .dataFormatType(metric.getDataFormatType()) - .isTag(metric.getIsTag()) - .description(metric.getDescription()) - .build(); + SchemaElement metricToAdd = SchemaElement.builder().dataSetId(resp.getId()) + .dataSetName(resp.getName()).model(metric.getModelId()).id(metric.getId()) + .name(metric.getName()).bizName(metric.getBizName()) + .type(SchemaElementType.METRIC).useCnt(metric.getUseCnt()).alias(alias) + .relatedSchemaElements(getRelateSchemaElement(metric)) + .defaultAgg(metric.getDefaultAgg()).dataFormatType(metric.getDataFormatType()) + .isTag(metric.getIsTag()).description(metric.getDescription()).build(); metrics.add(metricToAdd); } return metrics; @@ -250,18 +189,10 @@ public class DataSetSchemaBuilder { for (TermResp termResp : resp.getTermResps()) { List alias = termResp.getAlias(); SchemaElement metricToAdd = - SchemaElement.builder() - .dataSetId(resp.getId()) - .dataSetName(resp.getName()) - .model(-1L) - .id(termResp.getId()) - .name(termResp.getName()) - .bizName(termResp.getName()) - .type(SchemaElementType.TERM) - .useCnt(0L) - .alias(alias) - .description(termResp.getDescription()) - .build(); + SchemaElement.builder().dataSetId(resp.getId()).dataSetName(resp.getName()) + .model(-1L).id(termResp.getId()).name(termResp.getName()) + .bizName(termResp.getName()).type(SchemaElementType.TERM).useCnt(0L) + .alias(alias).description(termResp.getDescription()).build(); terms.add(metricToAdd); } return terms; @@ -274,26 +205,19 @@ public class DataSetSchemaBuilder { || CollectionUtils.isEmpty(relateDimension.getDrillDownDimensions())) { return Lists.newArrayList(); } - return relateDimension.getDrillDownDimensions().stream() - .map( - dimension -> { - RelatedSchemaElement relateSchemaElement = new RelatedSchemaElement(); - BeanUtils.copyProperties(dimension, relateSchemaElement); - return relateSchemaElement; - }) - .collect(Collectors.toList()); + return relateDimension.getDrillDownDimensions().stream().map(dimension -> { + RelatedSchemaElement relateSchemaElement = new RelatedSchemaElement(); + BeanUtils.copyProperties(dimension, relateSchemaElement); + return relateSchemaElement; + }).collect(Collectors.toList()); } - private static void setDefaultTimeFormat( - SchemaElement dimToAdd, - DimensionTimeTypeParams dimensionTimeTypeParams, - String timeFormat) { - if (null != dimensionTimeTypeParams - && TimeDimensionEnum.DAY - .name() - .equalsIgnoreCase(dimensionTimeTypeParams.getTimeGranularity())) { - dimToAdd.getExtInfo() - .put(DimensionConstants.DIMENSION_TIME_FORMAT, DateUtils.DEFAULT_DATE_FORMAT); + private static void setDefaultTimeFormat(SchemaElement dimToAdd, + DimensionTimeTypeParams dimensionTimeTypeParams, String timeFormat) { + if (null != dimensionTimeTypeParams && TimeDimensionEnum.DAY.name() + .equalsIgnoreCase(dimensionTimeTypeParams.getTimeGranularity())) { + dimToAdd.getExtInfo().put(DimensionConstants.DIMENSION_TIME_FORMAT, + DateUtils.DEFAULT_DATE_FORMAT); } else { dimToAdd.getExtInfo().put(DimensionConstants.DIMENSION_TIME_FORMAT, timeFormat); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DictUtils.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DictUtils.java index 65b27d9b2..ec9abcdc2 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DictUtils.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DictUtils.java @@ -90,11 +90,8 @@ public class DictUtils { private final ModelService modelService; private final TagMetaService tagMetaService; - public DictUtils( - DimensionService dimensionService, - MetricService metricService, - SemanticLayerService queryService, - ModelService modelService, + public DictUtils(DimensionService dimensionService, MetricService metricService, + SemanticLayerService queryService, ModelService modelService, @Lazy TagMetaService tagMetaService) { this.dimensionService = dimensionService; this.metricService = metricService; @@ -104,13 +101,12 @@ public class DictUtils { } public String fetchDictFileName(DictItemResp dictItemResp) { - return String.format( - "dic_value_%d_%s_%s", - dictItemResp.getModelId(), dictItemResp.getType().name(), dictItemResp.getItemId()); + return String.format("dic_value_%d_%s_%s", dictItemResp.getModelId(), + dictItemResp.getType().name(), dictItemResp.getItemId()); } - public DictTaskDO generateDictTaskDO( - DictItemResp dictItemResp, User user, TaskStatusEnum status) { + public DictTaskDO generateDictTaskDO(DictItemResp dictItemResp, User user, + TaskStatusEnum status) { DictTaskDO taskDO = new DictTaskDO(); Date createAt = new Date(); String name = dictItemResp.fetchDictFileName(); @@ -141,14 +137,12 @@ public class DictUtils { public List dictDOList2Req(List dictConfDOList) { List dictItemReqList = new ArrayList<>(); - dictConfDOList.stream() - .forEach( - conf -> { - DictItemResp dictItemResp = dictDO2Req(conf); - if (Objects.nonNull(dictItemResp)) { - dictItemReqList.add(dictDO2Req(conf)); - } - }); + dictConfDOList.stream().forEach(conf -> { + DictItemResp dictItemResp = dictDO2Req(conf); + if (Objects.nonNull(dictItemResp)) { + dictItemReqList.add(dictDO2Req(conf)); + } + }); return dictItemReqList; } @@ -190,10 +184,8 @@ public class DictUtils { Map valueAndFrequencyPair = new HashMap<>(2000); for (Map line : semanticQueryResp.getResultList()) { - if (CollectionUtils.isEmpty(line) - || !line.containsKey(bizName) - || line.get(bizName) == null - || line.size() != 2) { + if (CollectionUtils.isEmpty(line) || !line.containsKey(bizName) + || line.get(bizName) == null || line.size() != 2) { continue; } String dimValue = line.get(bizName).toString(); @@ -218,38 +210,35 @@ public class DictUtils { } private void addWhiteValueLines(DictItemResp dictItemResp, List lines, String nature) { - if (Objects.isNull(dictItemResp) - || Objects.isNull(dictItemResp.getConfig()) + if (Objects.isNull(dictItemResp) || Objects.isNull(dictItemResp.getConfig()) || CollectionUtils.isEmpty(dictItemResp.getConfig().getWhiteList())) { return; } List whiteList = dictItemResp.getConfig().getWhiteList(); - whiteList.forEach( - white -> { - if (!StringUtils.isEmpty(white)) { - white = white.replace(SPACE, POUND); - } - lines.add(String.format("%s %s %s", white, nature, itemValueWhiteFrequency)); - }); + whiteList.forEach(white -> { + if (!StringUtils.isEmpty(white)) { + white = white.replace(SPACE, POUND); + } + lines.add(String.format("%s %s %s", white, nature, itemValueWhiteFrequency)); + }); } - private void constructDictLines( - Map valueAndFrequencyPair, List lines, String nature) { + private void constructDictLines(Map valueAndFrequencyPair, List lines, + String nature) { if (CollectionUtils.isEmpty(valueAndFrequencyPair)) { return; } - valueAndFrequencyPair.forEach( - (value, frequency) -> { - if (!StringUtils.isEmpty(value)) { - value = value.replace(SPACE, POUND); - } - lines.add(String.format("%s %s %s", value, nature, frequency)); - }); + valueAndFrequencyPair.forEach((value, frequency) -> { + if (!StringUtils.isEmpty(value)) { + value = value.replace(SPACE, POUND); + } + lines.add(String.format("%s %s %s", value, nature, frequency)); + }); } - private void mergeMultivaluedValue( - Map valueAndFrequencyPair, String dimValue, Long metric) { + private void mergeMultivaluedValue(Map valueAndFrequencyPair, String dimValue, + Long metric) { if (StringUtils.isEmpty(dimValue)) { return; } @@ -263,8 +252,7 @@ public class DictUtils { for (String value : tmp.keySet()) { long metricOld = - valueAndFrequencyPair.containsKey(value) - ? valueAndFrequencyPair.get(value) + valueAndFrequencyPair.containsKey(value) ? valueAndFrequencyPair.get(value) : 0L; valueAndFrequencyPair.put(value, metric + metricOld); } @@ -286,8 +274,7 @@ public class DictUtils { String where = StringUtils.isEmpty(whereStr) ? "" : "WHERE" + whereStr; ItemValueConfig config = dictItemResp.getConfig(); int limit = - (Objects.isNull(config) || Objects.isNull(config.getLimit())) - ? itemValueMaxCount + (Objects.isNull(config) || Objects.isNull(config.getLimit())) ? itemValueMaxCount : dictItemResp.getConfig().getLimit(); // todo 自定义指标 @@ -312,8 +299,7 @@ public class DictUtils { } private QuerySqlReq constructDimQueryReq(DictItemResp dictItemResp) { - if (Objects.nonNull(dictItemResp) - && Objects.nonNull(dictItemResp.getConfig()) + if (Objects.nonNull(dictItemResp) && Objects.nonNull(dictItemResp.getConfig()) && Objects.nonNull(dictItemResp.getConfig().getMetricId())) { // 查询默认指标 QueryStructReq queryStructReq = generateQueryStruct(dictItemResp); @@ -332,8 +318,7 @@ public class DictUtils { String where = StringUtils.isEmpty(whereStr) ? "" : "WHERE" + whereStr; ItemValueConfig config = dictItemResp.getConfig(); long limit = - (Objects.isNull(config) || Objects.isNull(config.getLimit())) - ? itemValueMaxCount + (Objects.isNull(config) || Objects.isNull(config.getLimit())) ? itemValueMaxCount : dictItemResp.getConfig().getLimit(); String sql = String.format(sqlPattern, bizName, where, bizName, limit); Set modelIds = new HashSet<>(); @@ -371,10 +356,8 @@ public class DictUtils { fillStructDateInfo(queryStructReq, dictItemResp); - int limit = - Objects.isNull(dictItemResp.getConfig().getLimit()) - ? itemValueMaxCount - : dictItemResp.getConfig().getLimit(); + int limit = Objects.isNull(dictItemResp.getConfig().getLimit()) ? itemValueMaxCount + : dictItemResp.getConfig().getLimit(); queryStructReq.setLimit(limit); queryStructReq.setNeedAuth(false); return queryStructReq; @@ -410,25 +393,18 @@ public class DictUtils { } } - private void fillStructDateBetween( - QueryStructReq queryStructReq, - ModelResp model, - Integer itemValueDateStart, - Integer itemValueDateEnd) { + private void fillStructDateBetween(QueryStructReq queryStructReq, ModelResp model, + Integer itemValueDateStart, Integer itemValueDateEnd) { if (Objects.nonNull(model)) { List timeDims = model.getTimeDimension(); if (!CollectionUtils.isEmpty(timeDims)) { DateConf dateConf = new DateConf(); dateConf.setDateMode(DateConf.DateMode.BETWEEN); String format = timeDims.get(0).getDateFormat(); - String start = - LocalDate.now() - .minusDays(itemValueDateStart) - .format(DateTimeFormatter.ofPattern(format)); - String end = - LocalDate.now() - .minusDays(itemValueDateEnd) - .format(DateTimeFormatter.ofPattern(format)); + String start = LocalDate.now().minusDays(itemValueDateStart) + .format(DateTimeFormatter.ofPattern(format)); + String end = LocalDate.now().minusDays(itemValueDateEnd) + .format(DateTimeFormatter.ofPattern(format)); dateConf.setStartDate(start); dateConf.setEndDate(end); queryStructReq.setDateInfo(dateConf); @@ -477,17 +453,12 @@ public class DictUtils { public String defaultDateFilter() { String format = itemValueDateFormat; - String start = - LocalDate.now() - .minusDays(itemValueDateStart) - .format(DateTimeFormatter.ofPattern(format)); - String end = - LocalDate.now() - .minusDays(itemValueDateEnd) - .format(DateTimeFormatter.ofPattern(format)); - return String.format( - "( %s >= '%s' and %s <= '%s' )", - TimeDimensionEnum.DAY.getName(), start, TimeDimensionEnum.DAY.getName(), end); + String start = LocalDate.now().minusDays(itemValueDateStart) + .format(DateTimeFormatter.ofPattern(format)); + String end = LocalDate.now().minusDays(itemValueDateEnd) + .format(DateTimeFormatter.ofPattern(format)); + return String.format("( %s >= '%s' and %s <= '%s' )", TimeDimensionEnum.DAY.getName(), + start, TimeDimensionEnum.DAY.getName(), end); } private String generateDictDateFilter(DictItemResp dictItemResp) { @@ -502,11 +473,8 @@ public class DictUtils { } // 静态日期 if (DateConf.DateMode.BETWEEN.equals(config.getDateConf().getDateMode())) { - return String.format( - "( %s >= '%s' and %s <= '%s' )", - TimeDimensionEnum.DAY.getName(), - config.getDateConf().getStartDate(), - TimeDimensionEnum.DAY.getName(), + return String.format("( %s >= '%s' and %s <= '%s' )", TimeDimensionEnum.DAY.getName(), + config.getDateConf().getStartDate(), TimeDimensionEnum.DAY.getName(), config.getDateConf().getEndDate()); } // 动态日期 @@ -527,18 +495,12 @@ public class DictUtils { dateFormat = itemValueDateFormat; } String start = - LocalDate.now() - .minusDays(dictItemResp.getConfig().getDateConf().getUnit()) + LocalDate.now().minusDays(dictItemResp.getConfig().getDateConf().getUnit()) .format(DateTimeFormatter.ofPattern(dateFormat)); - String end = - LocalDate.now() - .minusDays(0) - .format(DateTimeFormatter.ofPattern(dateFormat)); - return String.format( - "( %s > '%s' and %s <= '%s' )", - TimeDimensionEnum.DAY.getName(), - start, - TimeDimensionEnum.DAY.getName(), + String end = LocalDate.now().minusDays(0) + .format(DateTimeFormatter.ofPattern(dateFormat)); + return String.format("( %s > '%s' and %s <= '%s' )", + TimeDimensionEnum.DAY.getName(), start, TimeDimensionEnum.DAY.getName(), end); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DimensionConverter.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DimensionConverter.java index 66a9fa200..5a8e3379f 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DimensionConverter.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DimensionConverter.java @@ -69,24 +69,20 @@ public class DimensionConverter { return dimensionDO; } - public static DimensionResp convert2DimensionResp( - DimensionDO dimensionDO, Map modelRespMap) { + public static DimensionResp convert2DimensionResp(DimensionDO dimensionDO, + Map modelRespMap) { DimensionResp dimensionResp = new DimensionResp(); BeanUtils.copyProperties(dimensionDO, dimensionResp); dimensionResp.setModelName( modelRespMap.getOrDefault(dimensionResp.getModelId(), new ModelResp()).getName()); - dimensionResp.setModelBizName( - modelRespMap - .getOrDefault(dimensionResp.getModelId(), new ModelResp()) - .getBizName()); + dimensionResp.setModelBizName(modelRespMap + .getOrDefault(dimensionResp.getModelId(), new ModelResp()).getBizName()); if (dimensionDO.getDefaultValues() != null) { dimensionResp.setDefaultValues( JSONObject.parseObject(dimensionDO.getDefaultValues(), List.class)); } - dimensionResp.setModelFilterSql( - modelRespMap - .getOrDefault(dimensionResp.getModelId(), new ModelResp()) - .getFilterSql()); + dimensionResp.setModelFilterSql(modelRespMap + .getOrDefault(dimensionResp.getModelId(), new ModelResp()).getFilterSql()); if (StringUtils.isNotEmpty(dimensionDO.getDimValueMaps())) { dimensionResp.setDimValueMaps( JsonUtil.toList(dimensionDO.getDimValueMaps(), DimValueMap.class)); @@ -98,9 +94,8 @@ public class DimensionConverter { dimensionResp.setExt(JSONObject.parseObject(dimensionDO.getExt(), Map.class)); } if (StringUtils.isNoneBlank(dimensionDO.getTypeParams())) { - dimensionResp.setTypeParams( - JSONObject.parseObject( - dimensionDO.getTypeParams(), DimensionTimeTypeParams.class)); + dimensionResp.setTypeParams(JSONObject.parseObject(dimensionDO.getTypeParams(), + DimensionTimeTypeParams.class)); } dimensionResp.setType(getType(dimensionDO.getType())); dimensionResp.setTypeEnum(TypeEnums.DIMENSION); @@ -122,15 +117,12 @@ public class DimensionConverter { } } - public static List filterByDataSet( - List dimensionResps, DataSetResp dataSetResp) { + public static List filterByDataSet(List dimensionResps, + DataSetResp dataSetResp) { return dimensionResps.stream() - .filter( - dimensionResp -> - dataSetResp.dimensionIds().contains(dimensionResp.getId()) - || dataSetResp - .getAllIncludeAllModels() - .contains(dimensionResp.getModelId())) + .filter(dimensionResp -> dataSetResp.dimensionIds().contains(dimensionResp.getId()) + || dataSetResp.getAllIncludeAllModels() + .contains(dimensionResp.getModelId())) .collect(Collectors.toList()); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DomainConvert.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DomainConvert.java index 25265fb32..9c766d74b 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DomainConvert.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/DomainConvert.java @@ -29,22 +29,14 @@ public class DomainConvert { DomainResp domainResp = new DomainResp(); BeanUtils.copyProperties(domainDO, domainResp); domainResp.setFullPath(domainFullPathMap.get(domainDO.getId())); - domainResp.setAdmins( - StringUtils.isBlank(domainDO.getAdmin()) - ? Lists.newArrayList() - : Arrays.asList(domainDO.getAdmin().split(","))); - domainResp.setAdminOrgs( - StringUtils.isBlank(domainDO.getAdminOrg()) - ? Lists.newArrayList() - : Arrays.asList(domainDO.getAdminOrg().split(","))); - domainResp.setViewers( - StringUtils.isBlank(domainDO.getViewer()) - ? Lists.newArrayList() - : Arrays.asList(domainDO.getViewer().split(","))); - domainResp.setViewOrgs( - StringUtils.isBlank(domainDO.getViewOrg()) - ? Lists.newArrayList() - : Arrays.asList(domainDO.getViewOrg().split(","))); + domainResp.setAdmins(StringUtils.isBlank(domainDO.getAdmin()) ? Lists.newArrayList() + : Arrays.asList(domainDO.getAdmin().split(","))); + domainResp.setAdminOrgs(StringUtils.isBlank(domainDO.getAdminOrg()) ? Lists.newArrayList() + : Arrays.asList(domainDO.getAdminOrg().split(","))); + domainResp.setViewers(StringUtils.isBlank(domainDO.getViewer()) ? Lists.newArrayList() + : Arrays.asList(domainDO.getViewer().split(","))); + domainResp.setViewOrgs(StringUtils.isBlank(domainDO.getViewOrg()) ? Lists.newArrayList() + : Arrays.asList(domainDO.getViewOrg().split(","))); return domainResp; } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/MetricConverter.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/MetricConverter.java index 23ed7322e..e53753818 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/MetricConverter.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/MetricConverter.java @@ -70,13 +70,13 @@ public class MetricConverter { return convert2MetricResp(metricDO, new HashMap<>(), Lists.newArrayList()); } - public static MetricResp convert2MetricResp( - MetricDO metricDO, Map modelMap, List collect) { + public static MetricResp convert2MetricResp(MetricDO metricDO, Map modelMap, + List collect) { MetricResp metricResp = new MetricResp(); BeanUtils.copyProperties(metricDO, metricResp); - metricResp.setDataFormat( - JSONObject.parseObject(metricDO.getDataFormat(), DataFormat.class)); + metricResp + .setDataFormat(JSONObject.parseObject(metricDO.getDataFormat(), DataFormat.class)); ModelResp modelResp = modelMap.get(metricDO.getModelId()); if (modelResp != null) { metricResp.setModelName(modelResp.getName()); @@ -97,17 +97,14 @@ public class MetricConverter { } metricResp.setTypeEnum(TypeEnums.METRIC); if (MetricDefineType.MEASURE.name().equalsIgnoreCase(metricDO.getDefineType())) { - metricResp.setMetricDefineByMeasureParams( - JSONObject.parseObject( - metricDO.getTypeParams(), MetricDefineByMeasureParams.class)); + metricResp.setMetricDefineByMeasureParams(JSONObject + .parseObject(metricDO.getTypeParams(), MetricDefineByMeasureParams.class)); } else if (MetricDefineType.METRIC.name().equalsIgnoreCase(metricDO.getDefineType())) { - metricResp.setMetricDefineByMetricParams( - JSONObject.parseObject( - metricDO.getTypeParams(), MetricDefineByMetricParams.class)); + metricResp.setMetricDefineByMetricParams(JSONObject + .parseObject(metricDO.getTypeParams(), MetricDefineByMetricParams.class)); } else if (MetricDefineType.FIELD.name().equalsIgnoreCase(metricDO.getDefineType())) { - metricResp.setMetricDefineByFieldParams( - JSONObject.parseObject( - metricDO.getTypeParams(), MetricDefineByFieldParams.class)); + metricResp.setMetricDefineByFieldParams(JSONObject.parseObject(metricDO.getTypeParams(), + MetricDefineByFieldParams.class)); } if (metricDO.getDefineType() != null) { metricResp.setMetricDefineType(MetricDefineType.valueOf(metricDO.getDefineType())); @@ -116,15 +113,11 @@ public class MetricConverter { return metricResp; } - public static List filterByDataSet( - List metricResps, DataSetResp dataSetResp) { + public static List filterByDataSet(List metricResps, + DataSetResp dataSetResp) { return metricResps.stream() - .filter( - metricResp -> - dataSetResp.metricIds().contains(metricResp.getId()) - || dataSetResp - .getAllIncludeAllModels() - .contains(metricResp.getModelId())) + .filter(metricResp -> dataSetResp.metricIds().contains(metricResp.getId()) + || dataSetResp.getAllIncludeAllModels().contains(metricResp.getModelId())) .collect(Collectors.toList()); } } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/MetricDrillDownChecker.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/MetricDrillDownChecker.java index 1c93afd6f..74470a188 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/MetricDrillDownChecker.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/MetricDrillDownChecker.java @@ -27,7 +27,8 @@ import java.util.stream.Collectors; @Slf4j public class MetricDrillDownChecker { - @Autowired private MetricService metricService; + @Autowired + private MetricService metricService; public void checkQuery(QueryStatement queryStatement) { SemanticSchemaResp semanticSchemaResp = queryStatement.getSemanticSchemaResp(); @@ -54,12 +55,8 @@ public class MetricDrillDownChecker { getNecessaryDimensionMissing(necessaryDimensions, dimensionFields); if (!CollectionUtils.isEmpty(dimensionsMissing)) { String errMsg = - String.format( - "指标:%s 缺失必要下钻维度:%s", - metric.getName(), - dimensionsMissing.stream() - .map(DimensionResp::getName) - .collect(Collectors.toList())); + String.format("指标:%s 缺失必要下钻维度:%s", metric.getName(), dimensionsMissing + .stream().map(DimensionResp::getName).collect(Collectors.toList())); throw new InvalidArgumentException(errMsg); } } @@ -95,28 +92,18 @@ public class MetricDrillDownChecker { * To check whether the dimension can drill down the metric, eg: some descriptive dimensions are * not suitable as drill-down dimensions */ - private boolean checkDrillDownDimension( - String dimensionName, - List metricResps, + private boolean checkDrillDownDimension(String dimensionName, List metricResps, SemanticSchemaResp semanticSchemaResp) { if (CollectionUtils.isEmpty(metricResps)) { return true; } - List relateDimensions = - metricResps.stream() - .map(this::getDrillDownDimensions) - .filter( - drillDownDimensions -> - !CollectionUtils.isEmpty(drillDownDimensions)) - .map( - drillDownDimensions -> - drillDownDimensions.stream() - .map(DrillDownDimension::getDimensionId) - .collect(Collectors.toList())) - .flatMap(Collection::stream) - .map(id -> convertDimensionIdToBizName(id, semanticSchemaResp)) - .filter(Objects::nonNull) - .collect(Collectors.toList()); + List relateDimensions = metricResps.stream().map(this::getDrillDownDimensions) + .filter(drillDownDimensions -> !CollectionUtils.isEmpty(drillDownDimensions)) + .map(drillDownDimensions -> drillDownDimensions.stream() + .map(DrillDownDimension::getDimensionId).collect(Collectors.toList())) + .flatMap(Collection::stream) + .map(id -> convertDimensionIdToBizName(id, semanticSchemaResp)) + .filter(Objects::nonNull).collect(Collectors.toList()); // if no metric has drill down dimension, return true if (CollectionUtils.isEmpty(relateDimensions)) { return true; @@ -125,8 +112,8 @@ public class MetricDrillDownChecker { return relateDimensions.contains(dimensionName); } - private List getNecessaryDimensions( - MetricSchemaResp metric, SemanticSchemaResp semanticSchemaResp) { + private List getNecessaryDimensions(MetricSchemaResp metric, + SemanticSchemaResp semanticSchemaResp) { if (metric == null) { return Lists.newArrayList(); } @@ -134,12 +121,9 @@ public class MetricDrillDownChecker { if (CollectionUtils.isEmpty(drillDownDimensions)) { return Lists.newArrayList(); } - return drillDownDimensions.stream() - .filter(DrillDownDimension::isNecessary) - .map(DrillDownDimension::getDimensionId) - .map(semanticSchemaResp::getDimension) - .filter(Objects::nonNull) - .collect(Collectors.toList()); + return drillDownDimensions.stream().filter(DrillDownDimension::isNecessary) + .map(DrillDownDimension::getDimensionId).map(semanticSchemaResp::getDimension) + .filter(Objects::nonNull).collect(Collectors.toList()); } private List getDimensionFields(List groupByFields, List whereFields) { @@ -153,8 +137,8 @@ public class MetricDrillDownChecker { return dimensionFields; } - private List getMetrics( - List metricFields, SemanticSchemaResp semanticSchemaResp) { + private List getMetrics(List metricFields, + SemanticSchemaResp semanticSchemaResp) { return semanticSchemaResp.getMetrics().stream() .filter(metricSchemaResp -> metricFields.contains(metricSchemaResp.getBizName())) .collect(Collectors.toList()); diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ModelClusterBuilder.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ModelClusterBuilder.java index ba6e17e20..61b10867d 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ModelClusterBuilder.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ModelClusterBuilder.java @@ -19,11 +19,8 @@ public class ModelClusterBuilder { public static Map buildModelClusters(List modelIds) { SchemaService schemaService = ContextUtils.getBean(SchemaService.class); List modelSchemaResps = schemaService.fetchModelSchemaResps(modelIds); - Map modelIdToModelSchema = - modelSchemaResps.stream() - .collect( - Collectors.toMap( - ModelSchemaResp::getId, value -> value, (k1, k2) -> k1)); + Map modelIdToModelSchema = modelSchemaResps.stream() + .collect(Collectors.toMap(ModelSchemaResp::getId, value -> value, (k1, k2) -> k1)); Set visited = new HashSet<>(); List> modelClusters = new ArrayList<>(); @@ -40,25 +37,17 @@ public class ModelClusterBuilder { .collect(Collectors.toMap(ModelCluster::getKey, value -> value, (k1, k2) -> k1)); } - private static ModelCluster getModelCluster( - Map modelIdToModelSchema, Set modelIds) { - boolean containsPartitionDimensions = - modelIds.stream() - .map(modelIdToModelSchema::get) - .filter(Objects::nonNull) - .anyMatch( - modelSchemaResp -> - CollectionUtils.isNotEmpty( - modelSchemaResp.getTimeDimension())); + private static ModelCluster getModelCluster(Map modelIdToModelSchema, + Set modelIds) { + boolean containsPartitionDimensions = modelIds.stream().map(modelIdToModelSchema::get) + .filter(Objects::nonNull).anyMatch(modelSchemaResp -> CollectionUtils + .isNotEmpty(modelSchemaResp.getTimeDimension())); return ModelCluster.build(modelIds, containsPartitionDimensions); } - private static void dfs( - ModelSchemaResp model, - Map modelMap, - Set visited, - Set modelCluster) { + private static void dfs(ModelSchemaResp model, Map modelMap, + Set visited, Set modelCluster) { visited.add(model.getId()); modelCluster.add(model.getId()); for (Long neighborId : model.getModelClusterSet()) { diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ModelConverter.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ModelConverter.java index aeed7b174..42b1c660e 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ModelConverter.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/ModelConverter.java @@ -55,22 +55,14 @@ public class ModelConverter { public static ModelResp convert(ModelDO modelDO) { ModelResp modelResp = new ModelResp(); BeanUtils.copyProperties(modelDO, modelResp); - modelResp.setAdmins( - StringUtils.isBlank(modelDO.getAdmin()) - ? Lists.newArrayList() - : Arrays.asList(modelDO.getAdmin().split(","))); - modelResp.setAdminOrgs( - StringUtils.isBlank(modelDO.getAdminOrg()) - ? Lists.newArrayList() - : Arrays.asList(modelDO.getAdminOrg().split(","))); - modelResp.setViewers( - StringUtils.isBlank(modelDO.getViewer()) - ? Lists.newArrayList() - : Arrays.asList(modelDO.getViewer().split(","))); - modelResp.setViewOrgs( - StringUtils.isBlank(modelDO.getViewOrg()) - ? Lists.newArrayList() - : Arrays.asList(modelDO.getViewOrg().split(","))); + modelResp.setAdmins(StringUtils.isBlank(modelDO.getAdmin()) ? Lists.newArrayList() + : Arrays.asList(modelDO.getAdmin().split(","))); + modelResp.setAdminOrgs(StringUtils.isBlank(modelDO.getAdminOrg()) ? Lists.newArrayList() + : Arrays.asList(modelDO.getAdminOrg().split(","))); + modelResp.setViewers(StringUtils.isBlank(modelDO.getViewer()) ? Lists.newArrayList() + : Arrays.asList(modelDO.getViewer().split(","))); + modelResp.setViewOrgs(StringUtils.isBlank(modelDO.getViewOrg()) ? Lists.newArrayList() + : Arrays.asList(modelDO.getViewOrg().split(","))); modelResp.setDrillDownDimensions( JsonUtil.toList(modelDO.getDrillDownDimensions(), DrillDownDimension.class)); modelResp.setModelDetail(JsonUtil.toObject(modelDO.getModelDetail(), ModelDetail.class)); @@ -129,8 +121,8 @@ public class ModelConverter { dimensionReq.setModelId(modelDO.getId()); dimensionReq.setExpr(dim.getBizName()); dimensionReq.setType(dim.getType()); - dimensionReq.setDescription( - Objects.isNull(dim.getDescription()) ? "" : dim.getDescription()); + dimensionReq + .setDescription(Objects.isNull(dim.getDescription()) ? "" : dim.getDescription()); dimensionReq.setIsTag(dim.getIsTag()); dimensionReq.setTypeParams(dim.getTypeParams()); return dimensionReq; @@ -189,8 +181,7 @@ public class ModelConverter { if (CollectionUtils.isEmpty(modelDetail.getDimensions())) { return Lists.newArrayList(); } - return modelDetail.getDimensions().stream() - .filter(ModelConverter::isCreateDimension) + return modelDetail.getDimensions().stream().filter(ModelConverter::isCreateDimension) .collect(Collectors.toList()); } @@ -198,8 +189,7 @@ public class ModelConverter { if (CollectionUtils.isEmpty(modelDetail.getIdentifiers())) { return Lists.newArrayList(); } - return modelDetail.getIdentifiers().stream() - .filter(ModelConverter::isCreateDimension) + return modelDetail.getIdentifiers().stream().filter(ModelConverter::isCreateDimension) .collect(Collectors.toList()); } @@ -207,8 +197,7 @@ public class ModelConverter { if (CollectionUtils.isEmpty(modelDetail.getMeasures())) { return Lists.newArrayList(); } - return modelDetail.getMeasures().stream() - .filter(ModelConverter::isCreateMetric) + return modelDetail.getMeasures().stream().filter(ModelConverter::isCreateMetric) .collect(Collectors.toList()); } @@ -218,20 +207,15 @@ public class ModelConverter { JSONObject.parseObject(modelDO.getModelDetail(), ModelDetail.class); List dims = getDimToCreateDimension(modelDetail); if (!CollectionUtils.isEmpty(dims)) { - dimensionReqs = - dims.stream() - .filter(dim -> StringUtils.isNotBlank(dim.getName())) - .map(dim -> convert(dim, modelDO)) - .collect(Collectors.toList()); + dimensionReqs = dims.stream().filter(dim -> StringUtils.isNotBlank(dim.getName())) + .map(dim -> convert(dim, modelDO)).collect(Collectors.toList()); } List identifies = getIdentityToCreateDimension(modelDetail); if (CollectionUtils.isEmpty(identifies)) { return dimensionReqs; } - dimensionReqs.addAll( - identifies.stream() - .map(identify -> convert(identify, modelDO)) - .collect(Collectors.toList())); + dimensionReqs.addAll(identifies.stream().map(identify -> convert(identify, modelDO)) + .collect(Collectors.toList())); return dimensionReqs; } @@ -242,8 +226,7 @@ public class ModelConverter { if (CollectionUtils.isEmpty(measures)) { return Lists.newArrayList(); } - return measures.stream() - .map(measure -> convert(measure, modelDO)) + return measures.stream().map(measure -> convert(measure, modelDO)) .collect(Collectors.toList()); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryReqConverter.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryReqConverter.java index 823c0750b..7bdbc1fb2 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryReqConverter.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryReqConverter.java @@ -51,9 +51,11 @@ import java.util.stream.Stream; @Slf4j public class QueryReqConverter { - @Autowired private QueryStructUtils queryStructUtils; + @Autowired + private QueryStructUtils queryStructUtils; - @Autowired private SqlGenerateUtils sqlGenerateUtils; + @Autowired + private SqlGenerateUtils sqlGenerateUtils; public QueryStatement convert(QuerySqlReq querySQLReq, SemanticSchemaResp semanticSchemaResp) throws Exception { @@ -93,17 +95,12 @@ public class QueryReqConverter { // if metric empty , fill model default if (CollectionUtils.isEmpty(metricTable.getMetrics())) { metricTable.setMetrics(new ArrayList<>()); - metricTable - .getMetrics() - .add( - sqlGenerateUtils.generateInternalMetricName( - getDefaultModel( - semanticSchemaResp, metricTable.getDimensions()))); + metricTable.getMetrics().add(sqlGenerateUtils.generateInternalMetricName( + getDefaultModel(semanticSchemaResp, metricTable.getDimensions()))); } else { - queryStructReq.setAggregators( - metricTable.getMetrics().stream() - .map(m -> new Aggregator(m, AggOperatorEnum.UNKNOWN)) - .collect(Collectors.toList())); + queryStructReq.setAggregators(metricTable.getMetrics().stream() + .map(m -> new Aggregator(m, AggOperatorEnum.UNKNOWN)) + .collect(Collectors.toList())); } AggOption aggOption = getAggOption(querySQLReq, metricSchemas); metricTable.setAggOption(aggOption); @@ -115,8 +112,8 @@ public class QueryReqConverter { result.setTables(tables); DatabaseResp database = semanticSchemaResp.getDatabaseResp(); - if (!sqlGenerateUtils.isSupportWith( - EngineType.fromString(database.getType().toUpperCase()), database.getVersion())) { + if (!sqlGenerateUtils.isSupportWith(EngineType.fromString(database.getType().toUpperCase()), + database.getVersion())) { result.setSupportWith(false); result.setWithAlias(false); } @@ -160,18 +157,13 @@ public class QueryReqConverter { if (databaseReq.isInnerLayerNative()) { return AggOption.NATIVE; } - if (SqlSelectHelper.hasSubSelect(sql) - || SqlSelectHelper.hasWith(sql) + if (SqlSelectHelper.hasSubSelect(sql) || SqlSelectHelper.hasWith(sql) || SqlSelectHelper.hasGroupBy(sql)) { return AggOption.OUTER; } - long defaultAggNullCnt = - metricSchemas.stream() - .filter( - m -> - Objects.isNull(m.getDefaultAgg()) - || StringUtils.isBlank(m.getDefaultAgg())) - .count(); + long defaultAggNullCnt = metricSchemas.stream().filter( + m -> Objects.isNull(m.getDefaultAgg()) || StringUtils.isBlank(m.getDefaultAgg())) + .count(); if (defaultAggNullCnt > 0) { log.debug("getAggOption find null defaultAgg metric set to NATIVE"); return AggOption.OUTER; @@ -179,32 +171,25 @@ public class QueryReqConverter { return AggOption.DEFAULT; } - private void convertNameToBizName( - QuerySqlReq querySqlReq, SemanticSchemaResp semanticSchemaResp) { + private void convertNameToBizName(QuerySqlReq querySqlReq, + SemanticSchemaResp semanticSchemaResp) { Map fieldNameToBizNameMap = getFieldNameToBizNameMap(semanticSchemaResp); String sql = querySqlReq.getSql(); - log.debug( - "dataSetId:{},convert name to bizName before:{}", querySqlReq.getDataSetId(), sql); + log.debug("dataSetId:{},convert name to bizName before:{}", querySqlReq.getDataSetId(), + sql); String replaceFields = SqlReplaceHelper.replaceFields(sql, fieldNameToBizNameMap, true); - log.debug( - "dataSetId:{},convert name to bizName after:{}", - querySqlReq.getDataSetId(), + log.debug("dataSetId:{},convert name to bizName after:{}", querySqlReq.getDataSetId(), replaceFields); querySqlReq.setSql(replaceFields); } - private Set getDimensions( - SemanticSchemaResp semanticSchemaResp, List allFields) { - Map dimensionLowerToNameMap = - semanticSchemaResp.getDimensions().stream() - .collect( - Collectors.toMap( - entry -> entry.getBizName().toLowerCase(), - SchemaItem::getBizName, - (k1, k2) -> k1)); - Map internalLowerToNameMap = - QueryStructUtils.internalCols.stream() - .collect(Collectors.toMap(String::toLowerCase, a -> a)); + private Set getDimensions(SemanticSchemaResp semanticSchemaResp, + List allFields) { + Map dimensionLowerToNameMap = semanticSchemaResp.getDimensions().stream() + .collect(Collectors.toMap(entry -> entry.getBizName().toLowerCase(), + SchemaItem::getBizName, (k1, k2) -> k1)); + Map internalLowerToNameMap = QueryStructUtils.internalCols.stream() + .collect(Collectors.toMap(String::toLowerCase, a -> a)); dimensionLowerToNameMap.putAll(internalLowerToNameMap); return allFields.stream() .filter(entry -> dimensionLowerToNameMap.containsKey(entry.toLowerCase())) @@ -212,21 +197,19 @@ public class QueryReqConverter { .collect(Collectors.toSet()); } - private List getMetrics( - SemanticSchemaResp semanticSchemaResp, List allFields) { + private List getMetrics(SemanticSchemaResp semanticSchemaResp, + List allFields) { Map metricLowerToNameMap = - semanticSchemaResp.getMetrics().stream() - .collect( - Collectors.toMap( - entry -> entry.getBizName().toLowerCase(), entry -> entry)); + semanticSchemaResp.getMetrics().stream().collect(Collectors + .toMap(entry -> entry.getBizName().toLowerCase(), entry -> entry)); return allFields.stream() .filter(entry -> metricLowerToNameMap.containsKey(entry.toLowerCase())) .map(entry -> metricLowerToNameMap.get(entry.toLowerCase())) .collect(Collectors.toList()); } - private void functionNameCorrector( - QuerySqlReq databaseReq, SemanticSchemaResp semanticSchemaResp) { + private void functionNameCorrector(QuerySqlReq databaseReq, + SemanticSchemaResp semanticSchemaResp) { DatabaseResp database = semanticSchemaResp.getDatabaseResp(); if (Objects.isNull(database) || Objects.isNull(database.getType())) { return; @@ -242,25 +225,13 @@ public class QueryReqConverter { protected Map getFieldNameToBizNameMap(SemanticSchemaResp semanticSchemaResp) { // support fieldName and field alias to bizName - Map dimensionResults = - semanticSchemaResp.getDimensions().stream() - .flatMap( - entry -> - getPairStream( - entry.getAlias(), - entry.getName(), - entry.getBizName())) - .collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1)); + Map dimensionResults = semanticSchemaResp.getDimensions().stream().flatMap( + entry -> getPairStream(entry.getAlias(), entry.getName(), entry.getBizName())) + .collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1)); - Map metricResults = - semanticSchemaResp.getMetrics().stream() - .flatMap( - entry -> - getPairStream( - entry.getAlias(), - entry.getName(), - entry.getBizName())) - .collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1)); + Map metricResults = semanticSchemaResp.getMetrics().stream().flatMap( + entry -> getPairStream(entry.getAlias(), entry.getName(), entry.getBizName())) + .collect(Collectors.toMap(Pair::getLeft, Pair::getRight, (k1, k2) -> k1)); dimensionResults.putAll(TimeDimensionEnum.getChNameToNameMap()); dimensionResults.putAll(TimeDimensionEnum.getNameToNameMap()); @@ -268,8 +239,8 @@ public class QueryReqConverter { return dimensionResults; } - private Stream> getPairStream( - String aliasStr, String name, String bizName) { + private Stream> getPairStream(String aliasStr, String name, + String bizName) { Set> elements = new HashSet<>(); elements.add(Pair.of(name, bizName)); if (StringUtils.isNotBlank(aliasStr)) { @@ -283,9 +254,8 @@ public class QueryReqConverter { public void correctTableName(QuerySqlReq querySqlReq) { String sql = querySqlReq.getSql(); - sql = - SqlReplaceHelper.replaceTable( - sql, Constants.TABLE_PREFIX + querySqlReq.getDataSetId()); + sql = SqlReplaceHelper.replaceTable(sql, + Constants.TABLE_PREFIX + querySqlReq.getDataSetId()); log.debug("correctTableName after:{}", sql); querySqlReq.setSql(sql); } @@ -299,21 +269,14 @@ public class QueryReqConverter { return queryType; } - private void generateDerivedMetric( - SemanticSchemaResp semanticSchemaResp, - AggOption aggOption, + private void generateDerivedMetric(SemanticSchemaResp semanticSchemaResp, AggOption aggOption, DataSetQueryParam viewQueryParam) { String sql = viewQueryParam.getSql(); for (MetricTable metricTable : viewQueryParam.getTables()) { Set measures = new HashSet<>(); Map replaces = new HashMap<>(); - generateDerivedMetric( - semanticSchemaResp, - aggOption, - metricTable.getMetrics(), - metricTable.getDimensions(), - measures, - replaces); + generateDerivedMetric(semanticSchemaResp, aggOption, metricTable.getMetrics(), + metricTable.getDimensions(), measures, replaces); if (!CollectionUtils.isEmpty(replaces)) { // metricTable sql use measures replace metric sql = SqlReplaceHelper.replaceSqlByExpression(sql, replaces); @@ -324,72 +287,46 @@ public class QueryReqConverter { } else { // empty measure , fill default metricTable.setMetrics(new ArrayList<>()); - metricTable - .getMetrics() - .add( - sqlGenerateUtils.generateInternalMetricName( - getDefaultModel( - semanticSchemaResp, - metricTable.getDimensions()))); + metricTable.getMetrics().add(sqlGenerateUtils.generateInternalMetricName( + getDefaultModel(semanticSchemaResp, metricTable.getDimensions()))); } } } viewQueryParam.setSql(sql); } - private void generateDerivedMetric( - SemanticSchemaResp semanticSchemaResp, - AggOption aggOption, - List metrics, - List dimensions, - Set measures, + private void generateDerivedMetric(SemanticSchemaResp semanticSchemaResp, AggOption aggOption, + List metrics, List dimensions, Set measures, Map replaces) { List metricResps = semanticSchemaResp.getMetrics(); List dimensionResps = semanticSchemaResp.getDimensions(); // check metrics has derived - if (!metricResps.stream() - .anyMatch( - m -> - metrics.contains(m.getBizName()) - && MetricType.isDerived( - m.getMetricDefineType(), - m.getMetricDefineByMeasureParams()))) { + if (!metricResps.stream().anyMatch(m -> metrics.contains(m.getBizName()) && MetricType + .isDerived(m.getMetricDefineType(), m.getMetricDefineByMeasureParams()))) { return; } log.debug("begin to generateDerivedMetric {} [{}]", aggOption, metrics); Set allFields = new HashSet<>(); Map allMeasures = new HashMap<>(); - semanticSchemaResp - .getModelResps() - .forEach( - modelResp -> { - allFields.addAll(modelResp.getFieldList()); - if (Objects.nonNull(modelResp.getModelDetail().getMeasures())) { - modelResp.getModelDetail().getMeasures().stream() - .forEach(mm -> allMeasures.put(mm.getBizName(), mm)); - } - }); + semanticSchemaResp.getModelResps().forEach(modelResp -> { + allFields.addAll(modelResp.getFieldList()); + if (Objects.nonNull(modelResp.getModelDetail().getMeasures())) { + modelResp.getModelDetail().getMeasures().stream() + .forEach(mm -> allMeasures.put(mm.getBizName(), mm)); + } + }); Set deriveDimension = new HashSet<>(); Set deriveMetric = new HashSet<>(); Set visitedMetric = new HashSet<>(); if (!CollectionUtils.isEmpty(metricResps)) { for (MetricResp metricResp : metricResps) { if (metrics.contains(metricResp.getBizName())) { - if (MetricType.isDerived( - metricResp.getMetricDefineType(), + if (MetricType.isDerived(metricResp.getMetricDefineType(), metricResp.getMetricDefineByMeasureParams())) { - String expr = - sqlGenerateUtils.generateDerivedMetric( - metricResps, - allFields, - allMeasures, - dimensionResps, - sqlGenerateUtils.getExpr(metricResp), - metricResp.getMetricDefineType(), - aggOption, - visitedMetric, - deriveMetric, - deriveDimension); + String expr = sqlGenerateUtils.generateDerivedMetric(metricResps, allFields, + allMeasures, dimensionResps, sqlGenerateUtils.getExpr(metricResp), + metricResp.getMetricDefineType(), aggOption, visitedMetric, + deriveMetric, deriveDimension); replaces.put(metricResp.getBizName(), expr); log.debug("derived metric {}->{}", metricResp.getBizName(), expr); } else { @@ -399,8 +336,7 @@ public class QueryReqConverter { } } measures.addAll(deriveMetric); - deriveDimension.stream() - .filter(d -> !dimensions.contains(d)) + deriveDimension.stream().filter(d -> !dimensions.contains(d)) .forEach(d -> dimensions.add(d)); } @@ -408,17 +344,12 @@ public class QueryReqConverter { if (!CollectionUtils.isEmpty(dimensions)) { Map modelMatchCnt = new HashMap<>(); for (ModelResp modelResp : semanticSchemaResp.getModelResps()) { - modelMatchCnt.put( - modelResp.getBizName(), - modelResp.getModelDetail().getDimensions().stream() - .filter(d -> dimensions.contains(d.getBizName())) - .count()); + modelMatchCnt.put(modelResp.getBizName(), modelResp.getModelDetail().getDimensions() + .stream().filter(d -> dimensions.contains(d.getBizName())).count()); } return modelMatchCnt.entrySet().stream() .sorted(Map.Entry.comparingByValue(Comparator.reverseOrder())) - .map(m -> m.getKey()) - .findFirst() - .orElse(""); + .map(m -> m.getKey()).findFirst().orElse(""); } return semanticSchemaResp.getModelResps().get(0).getBizName(); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryRuleConverter.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryRuleConverter.java index 706ab6d5f..74bb24b5c 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryRuleConverter.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryRuleConverter.java @@ -21,10 +21,8 @@ public class QueryRuleConverter { BeanUtils.copyProperties(queryRuleReq, queryRuleDO); queryRuleDO.setRuleType(queryRuleReq.getRuleType().name()); queryRuleDO.setRule(JsonUtil.toString(queryRuleReq.getRule())); - queryRuleDO.setAction( - Objects.isNull(queryRuleReq.getAction()) - ? "" - : JsonUtil.toString(queryRuleReq.getAction())); + queryRuleDO.setAction(Objects.isNull(queryRuleReq.getAction()) ? "" + : JsonUtil.toString(queryRuleReq.getAction())); queryRuleDO.setExt(JsonUtil.toString(queryRuleReq.getExt())); return queryRuleDO; @@ -35,10 +33,8 @@ public class QueryRuleConverter { BeanUtils.copyProperties(queryRuleDO, queryRuleResp); queryRuleResp.setRuleType(QueryRuleType.valueOf(queryRuleDO.getRuleType())); queryRuleResp.setRule(JsonUtil.toObject(queryRuleDO.getRule(), RuleInfo.class)); - queryRuleResp.setAction( - StringUtils.isEmpty(queryRuleDO.getAction()) - ? new ActionInfo() - : JsonUtil.toObject(queryRuleDO.getAction(), ActionInfo.class)); + queryRuleResp.setAction(StringUtils.isEmpty(queryRuleDO.getAction()) ? new ActionInfo() + : JsonUtil.toObject(queryRuleDO.getAction(), ActionInfo.class)); queryRuleResp.setExt(JsonUtil.toMap(queryRuleDO.getExt(), String.class, String.class)); return queryRuleResp; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryStructUtils.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryStructUtils.java index edd43cbef..21fb67cd1 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryStructUtils.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryStructUtils.java @@ -54,9 +54,7 @@ public class QueryStructUtils { private final SqlFilterUtils sqlFilterUtils; private final SchemaService schemaService; - public QueryStructUtils( - DateModeUtils dateModeUtils, - SqlFilterUtils sqlFilterUtils, + public QueryStructUtils(DateModeUtils dateModeUtils, SqlFilterUtils sqlFilterUtils, SchemaService schemaService) { this.dateModeUtils = dateModeUtils; @@ -121,53 +119,43 @@ public class QueryStructUtils { return new HashSet<>(SqlSelectHelper.getAllSelectFields(querySqlReq.getSql())); } - public Set getModelIdsFromStruct( - QueryStructReq queryStructReq, SemanticSchemaResp semanticSchemaResp) { + public Set getModelIdsFromStruct(QueryStructReq queryStructReq, + SemanticSchemaResp semanticSchemaResp) { Set modelIds = Sets.newHashSet(); Set bizNameFromStruct = getBizNameFromStruct(queryStructReq); - modelIds.addAll( - semanticSchemaResp.getMetrics().stream() - .filter(metric -> bizNameFromStruct.contains(metric.getBizName())) - .map(MetricResp::getModelId) - .collect(Collectors.toSet())); - modelIds.addAll( - semanticSchemaResp.getDimensions().stream() - .filter(dimension -> bizNameFromStruct.contains(dimension.getBizName())) - .map(DimensionResp::getModelId) - .collect(Collectors.toList())); + modelIds.addAll(semanticSchemaResp.getMetrics().stream() + .filter(metric -> bizNameFromStruct.contains(metric.getBizName())) + .map(MetricResp::getModelId).collect(Collectors.toSet())); + modelIds.addAll(semanticSchemaResp.getDimensions().stream() + .filter(dimension -> bizNameFromStruct.contains(dimension.getBizName())) + .map(DimensionResp::getModelId).collect(Collectors.toList())); return modelIds; } - private List getMetricsFromSql( - QuerySqlReq querySqlReq, SemanticSchemaResp semanticSchemaResp) { + private List getMetricsFromSql(QuerySqlReq querySqlReq, + SemanticSchemaResp semanticSchemaResp) { Set resNameSet = getResName(querySqlReq); if (semanticSchemaResp != null) { - return semanticSchemaResp.getMetrics().stream() - .filter( - m -> - resNameSet.contains(m.getName()) - || resNameSet.contains(m.getBizName())) + return semanticSchemaResp.getMetrics().stream().filter( + m -> resNameSet.contains(m.getName()) || resNameSet.contains(m.getBizName())) .collect(Collectors.toList()); } return Lists.newArrayList(); } - private List getDimensionsFromSql( - QuerySqlReq querySqlReq, SemanticSchemaResp semanticSchemaResp) { + private List getDimensionsFromSql(QuerySqlReq querySqlReq, + SemanticSchemaResp semanticSchemaResp) { Set resNameSet = getResName(querySqlReq); if (semanticSchemaResp != null) { - return semanticSchemaResp.getDimensions().stream() - .filter( - m -> - resNameSet.contains(m.getName()) - || resNameSet.contains(m.getBizName())) + return semanticSchemaResp.getDimensions().stream().filter( + m -> resNameSet.contains(m.getName()) || resNameSet.contains(m.getBizName())) .collect(Collectors.toList()); } return Lists.newArrayList(); } - public Set getModelIdFromSql( - QuerySqlReq querySqlReq, SemanticSchemaResp semanticSchemaResp) { + public Set getModelIdFromSql(QuerySqlReq querySqlReq, + SemanticSchemaResp semanticSchemaResp) { Set modelIds = Sets.newHashSet(); List dimensions = getDimensionsFromSql(querySqlReq, semanticSchemaResp); List metrics = getMetricsFromSql(querySqlReq, semanticSchemaResp); @@ -177,8 +165,8 @@ public class QueryStructUtils { return modelIds; } - public Set getBizNameFromSql( - QuerySqlReq querySqlReq, SemanticSchemaResp semanticSchemaResp) { + public Set getBizNameFromSql(QuerySqlReq querySqlReq, + SemanticSchemaResp semanticSchemaResp) { Set bizNames = Sets.newHashSet(); List dimensions = getDimensionsFromSql(querySqlReq, semanticSchemaResp); List metrics = getMetricsFromSql(querySqlReq, semanticSchemaResp); @@ -191,10 +179,9 @@ public class QueryStructUtils { public ItemDateResp getItemDateResp(QueryStructReq queryStructCmd) { List dimensionIds = getDimensionIds(queryStructCmd); List metricIds = getMetricIds(queryStructCmd); - ItemDateResp dateDate = - schemaService.getItemDate( - new ItemDateFilter(dimensionIds, TypeEnums.DIMENSION.name()), - new ItemDateFilter(metricIds, TypeEnums.METRIC.name())); + ItemDateResp dateDate = schemaService.getItemDate( + new ItemDateFilter(dimensionIds, TypeEnums.DIMENSION.name()), + new ItemDateFilter(metricIds, TypeEnums.METRIC.name())); return dateDate; } @@ -212,17 +199,14 @@ public class QueryStructUtils { case BETWEEN: return Triple.of(dateInfo, dateConf.getStartDate(), dateConf.getEndDate()); case LIST: - return Triple.of( - dateInfo, - Collections.min(dateConf.getDateList()), + return Triple.of(dateInfo, Collections.min(dateConf.getDateList()), Collections.max(dateConf.getDateList())); case RECENT: ItemDateResp dateDate = getItemDateResp(queryStructCmd); LocalDate dateMax = LocalDate.now().minusDays(1); LocalDate dateMin = dateMax.minusDays(dateConf.getUnit() - 1); if (Objects.isNull(dateDate)) { - return Triple.of( - dateInfo, + return Triple.of(dateInfo, dateMin.format(DateTimeFormatter.ofPattern(DAY_FORMAT)), dateMax.format(DateTimeFormatter.ofPattern(DAY_FORMAT))); } @@ -240,11 +224,8 @@ public class QueryStructUtils { dateModeUtils.recentMonth(dateDate, dateConf); Optional minBegins = rets.stream().map(i -> i.left).sorted().findFirst(); - Optional maxBegins = - rets.stream() - .map(i -> i.right) - .sorted(Comparator.reverseOrder()) - .findFirst(); + Optional maxBegins = rets.stream().map(i -> i.right) + .sorted(Comparator.reverseOrder()).findFirst(); if (minBegins.isPresent() && maxBegins.isPresent()) { return Triple.of(dateInfo, minBegins.get(), maxBegins.get()); } diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryUtils.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryUtils.java index 8c71ad51c..597b59c9a 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryUtils.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/QueryUtils.java @@ -37,8 +37,8 @@ public class QueryUtils { @Value("${s2.query-optimizer.enable:true}") private Boolean optimizeEnable; - public void populateQueryColumns( - SemanticQueryResp semanticQueryResp, SemanticSchemaResp semanticSchemaResp) { + public void populateQueryColumns(SemanticQueryResp semanticQueryResp, + SemanticSchemaResp semanticSchemaResp) { Map metricRespMap = createMetricRespMap(semanticSchemaResp); Map namePair = new HashMap<>(); Map nameTypePair = new HashMap<>(); @@ -53,36 +53,24 @@ public class QueryUtils { .collect(Collectors.toMap(MetricResp::getBizName, a -> a, (k1, k2) -> k1)); } - private void populateNamePairs( - SemanticSchemaResp semanticSchemaResp, - Map namePair, - Map nameTypePair) { + private void populateNamePairs(SemanticSchemaResp semanticSchemaResp, + Map namePair, Map nameTypePair) { for (TimeDimensionEnum timeDimensionEnum : TimeDimensionEnum.values()) { namePair.put(timeDimensionEnum.getName(), "date"); nameTypePair.put(timeDimensionEnum.getName(), "DATE"); } - semanticSchemaResp - .getMetrics() - .forEach( - metricDesc -> { - namePair.put(metricDesc.getBizName(), metricDesc.getName()); - nameTypePair.put(metricDesc.getBizName(), SemanticType.NUMBER.name()); - }); - semanticSchemaResp - .getDimensions() - .forEach( - dimensionDesc -> { - namePair.put(dimensionDesc.getBizName(), dimensionDesc.getName()); - nameTypePair.put( - dimensionDesc.getBizName(), dimensionDesc.getSemanticType()); - }); + semanticSchemaResp.getMetrics().forEach(metricDesc -> { + namePair.put(metricDesc.getBizName(), metricDesc.getName()); + nameTypePair.put(metricDesc.getBizName(), SemanticType.NUMBER.name()); + }); + semanticSchemaResp.getDimensions().forEach(dimensionDesc -> { + namePair.put(dimensionDesc.getBizName(), dimensionDesc.getName()); + nameTypePair.put(dimensionDesc.getBizName(), dimensionDesc.getSemanticType()); + }); } - private void processColumn( - QueryColumn column, - Map namePair, - Map nameTypePair, - Map metricRespMap) { + private void processColumn(QueryColumn column, Map namePair, + Map nameTypePair, Map metricRespMap) { String nameEn = getName(column.getNameEn().toLowerCase()); if (nameEn.contains(JOIN_UNDERLINE)) { nameEn = nameEn.split(JOIN_UNDERLINE)[1]; @@ -126,14 +114,10 @@ public class QueryUtils { if (StringUtils.isBlank(type)) { return false; } - return type.equalsIgnoreCase("int") - || type.equalsIgnoreCase("bigint") - || type.equalsIgnoreCase("float") - || type.equalsIgnoreCase("double") - || type.equalsIgnoreCase("numeric") - || type.toLowerCase().startsWith("decimal") - || type.toLowerCase().startsWith("uint") - || type.toLowerCase().startsWith("int"); + return type.equalsIgnoreCase("int") || type.equalsIgnoreCase("bigint") + || type.equalsIgnoreCase("float") || type.equalsIgnoreCase("double") + || type.equalsIgnoreCase("numeric") || type.toLowerCase().startsWith("decimal") + || type.toLowerCase().startsWith("uint") || type.toLowerCase().startsWith("int"); } private String getName(String nameEn) { @@ -156,23 +140,19 @@ public class QueryUtils { return null; } - public QueryStatement sqlParserUnion( - QueryMultiStructReq queryMultiStructCmd, List sqlParsers) { + public QueryStatement sqlParserUnion(QueryMultiStructReq queryMultiStructCmd, + List sqlParsers) { QueryStatement sqlParser = new QueryStatement(); StringBuilder unionSqlBuilder = new StringBuilder(); for (int i = 0; i < sqlParsers.size(); i++) { - String selectStr = - SqlGenerateUtils.getUnionSelect( - queryMultiStructCmd.getQueryStructReqs().get(i)); - unionSqlBuilder.append( - String.format( - "select %s from ( %s ) sub_sql_%s", - selectStr, sqlParsers.get(i).getSql(), i)); + String selectStr = SqlGenerateUtils + .getUnionSelect(queryMultiStructCmd.getQueryStructReqs().get(i)); + unionSqlBuilder.append(String.format("select %s from ( %s ) sub_sql_%s", selectStr, + sqlParsers.get(i).getSql(), i)); unionSqlBuilder.append(UNIONALL); } - String unionSql = - unionSqlBuilder.substring( - 0, unionSqlBuilder.length() - Constants.UNIONALL.length()); + String unionSql = unionSqlBuilder.substring(0, + unionSqlBuilder.length() - Constants.UNIONALL.length()); sqlParser.setSql(unionSql); log.info("union sql parser:{}", sqlParser); return sqlParser; diff --git a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/StatUtils.java b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/StatUtils.java index 9496aa71c..2a468772f 100644 --- a/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/StatUtils.java +++ b/headless/server/src/main/java/com/tencent/supersonic/headless/server/utils/StatUtils.java @@ -65,15 +65,12 @@ public class StatUtils { QueryStat queryStatInfo = get(); queryStatInfo.setElapsedMs(System.currentTimeMillis() - queryStatInfo.getStartTime()); queryStatInfo.setQueryState(state.getStatus()); - CompletableFuture.runAsync( - () -> { - statRepository.createRecord(queryStatInfo); - }) - .exceptionally( - exception -> { - log.warn("queryStatInfo, exception:", exception); - return null; - }); + CompletableFuture.runAsync(() -> { + statRepository.createRecord(queryStatInfo); + }).exceptionally(exception -> { + log.warn("queryStatInfo, exception:", exception); + return null; + }); remove(); } @@ -111,10 +108,7 @@ public class StatUtils { String user = getUserName(facadeUser); try { - queryStatInfo - .setTraceId(traceId) - .setDataSetId(queryTagReq.getDataSetId()) - .setUser(user) + queryStatInfo.setTraceId(traceId).setDataSetId(queryTagReq.getDataSetId()).setUser(user) .setQueryType(QueryMethod.STRUCT.getValue()) .setQueryTypeBack(QueryTypeBack.NORMAL.getState()) .setQueryStructCmd(queryTagReq.toString()) @@ -124,11 +118,9 @@ public class StatUtils { .setGroupByCols(objectMapper.writeValueAsString(queryTagReq.getGroups())) .setAggCols(objectMapper.writeValueAsString(queryTagReq.getAggregators())) .setOrderByCols(objectMapper.writeValueAsString(queryTagReq.getOrders())) - .setFilterCols( - objectMapper.writeValueAsString( - sqlFilterUtils.getFiltersCol(queryTagReq.getTagFilters()))) - .setUseResultCache(true) - .setUseSqlCache(true) + .setFilterCols(objectMapper.writeValueAsString( + sqlFilterUtils.getFiltersCol(queryTagReq.getTagFilters()))) + .setUseResultCache(true).setUseSqlCache(true) .setMetrics(objectMapper.writeValueAsString(metrics)) .setDimensions(objectMapper.writeValueAsString(dimensions)) .setQueryOptMode(QueryOptMode.NONE.name()); @@ -150,18 +142,13 @@ public class StatUtils { String userName = getUserName(facadeUser); try { - queryStatInfo - .setTraceId("") - .setUser(userName) - .setDataSetId(querySqlReq.getDataSetId()) + queryStatInfo.setTraceId("").setUser(userName).setDataSetId(querySqlReq.getDataSetId()) .setQueryType(QueryMethod.SQL.getValue()) .setQueryTypeBack(QueryTypeBack.NORMAL.getState()) .setQuerySqlCmd(querySqlReq.toString()) .setQuerySqlCmdMd5(DigestUtils.md5Hex(querySqlReq.toString())) - .setStartTime(System.currentTimeMillis()) - .setUseResultCache(true) - .setUseSqlCache(true) - .setMetrics(objectMapper.writeValueAsString(aggFields)) + .setStartTime(System.currentTimeMillis()).setUseResultCache(true) + .setUseSqlCache(true).setMetrics(objectMapper.writeValueAsString(aggFields)) .setDimensions(objectMapper.writeValueAsString(dimensions)); if (!CollectionUtils.isEmpty(querySqlReq.getModelIds())) { queryStatInfo.setModelId(querySqlReq.getModelIds().get(0)); @@ -183,11 +170,8 @@ public class StatUtils { String user = getUserName(facadeUser); try { - queryStatInfo - .setTraceId(traceId) - .setDataSetId(queryStructReq.getDataSetId()) - .setUser(user) - .setQueryType(QueryMethod.STRUCT.getValue()) + queryStatInfo.setTraceId(traceId).setDataSetId(queryStructReq.getDataSetId()) + .setUser(user).setQueryType(QueryMethod.STRUCT.getValue()) .setQueryTypeBack(QueryTypeBack.NORMAL.getState()) .setQueryStructCmd(queryStructReq.toString()) .setQueryStructCmdMd5(DigestUtils.md5Hex(queryStructReq.toString())) @@ -196,12 +180,9 @@ public class StatUtils { .setGroupByCols(objectMapper.writeValueAsString(queryStructReq.getGroups())) .setAggCols(objectMapper.writeValueAsString(queryStructReq.getAggregators())) .setOrderByCols(objectMapper.writeValueAsString(queryStructReq.getOrders())) - .setFilterCols( - objectMapper.writeValueAsString( - sqlFilterUtils.getFiltersCol( - queryStructReq.getOriginalFilter()))) - .setUseResultCache(true) - .setUseSqlCache(true) + .setFilterCols(objectMapper.writeValueAsString( + sqlFilterUtils.getFiltersCol(queryStructReq.getOriginalFilter()))) + .setUseResultCache(true).setUseSqlCache(true) .setMetrics(objectMapper.writeValueAsString(metrics)) .setDimensions(objectMapper.writeValueAsString(dimensions)) .setQueryOptMode(QueryOptMode.NONE.name()); @@ -214,15 +195,12 @@ public class StatUtils { StatUtils.set(queryStatInfo); } - private List getFieldNames( - List allFields, List schemaItems) { - Set fieldNames = - schemaItems.stream() - .map(dimSchemaResp -> dimSchemaResp.getBizName()) - .collect(Collectors.toSet()); + private List getFieldNames(List allFields, + List schemaItems) { + Set fieldNames = schemaItems.stream() + .map(dimSchemaResp -> dimSchemaResp.getBizName()).collect(Collectors.toSet()); if (!CollectionUtils.isEmpty(fieldNames)) { - return allFields.stream() - .filter(fieldName -> fieldNames.contains(fieldName)) + return allFields.stream().filter(fieldName -> fieldNames.contains(fieldName)) .collect(Collectors.toList()); } return new ArrayList<>(); diff --git a/headless/server/src/test/java/com/tencent/supersonic/headless/server/aspect/MetricDrillDownCheckerTest.java b/headless/server/src/test/java/com/tencent/supersonic/headless/server/aspect/MetricDrillDownCheckerTest.java index 0ec5da101..6d6e5d060 100644 --- a/headless/server/src/test/java/com/tencent/supersonic/headless/server/aspect/MetricDrillDownCheckerTest.java +++ b/headless/server/src/test/java/com/tencent/supersonic/headless/server/aspect/MetricDrillDownCheckerTest.java @@ -31,8 +31,7 @@ public class MetricDrillDownCheckerTest { MetricDrillDownChecker metricDrillDownChecker = new MetricDrillDownChecker(); String sql = "select page, sum(pv) from t_1 group by page"; SemanticSchemaResp semanticSchemaResp = mockModelSchemaResp(); - assertThrows( - InvalidArgumentException.class, + assertThrows(InvalidArgumentException.class, () -> metricDrillDownChecker.checkQuery(semanticSchemaResp, sql)); } @@ -41,8 +40,7 @@ public class MetricDrillDownCheckerTest { MetricDrillDownChecker metricDrillDownChecker = new MetricDrillDownChecker(); String sql = "select user_name, count(distinct uv) from t_1 group by user_name"; SemanticSchemaResp semanticSchemaResp = mockModelSchemaResp(); - assertThrows( - InvalidArgumentException.class, + assertThrows(InvalidArgumentException.class, () -> metricDrillDownChecker.checkQuery(semanticSchemaResp, sql)); } @@ -79,26 +77,23 @@ public class MetricDrillDownCheckerTest { } private List mockDimensions() { - return Lists.newArrayList( - DataUtils.mockDimension(1L, "user_name", "用户名"), + return Lists.newArrayList(DataUtils.mockDimension(1L, "user_name", "用户名"), DataUtils.mockDimension(2L, "department", "部门"), DataUtils.mockDimension(3L, "page", "页面")); } private List mockMetrics() { - return Lists.newArrayList( - DataUtils.mockMetric( - 1L, - "pv", - "访问次数", - Lists.newArrayList(new DrillDownDimension(1L), new DrillDownDimension(2L))), - DataUtils.mockMetric( - 2L, "uv", "访问用户数", Lists.newArrayList(new DrillDownDimension(2L, true)))); + return Lists + .newArrayList( + DataUtils.mockMetric(1L, "pv", "访问次数", + Lists.newArrayList(new DrillDownDimension(1L), + new DrillDownDimension(2L))), + DataUtils.mockMetric(2L, "uv", "访问用户数", + Lists.newArrayList(new DrillDownDimension(2L, true)))); } private List mockMetricsNoDrillDownSetting() { - return Lists.newArrayList( - DataUtils.mockMetric(1L, "pv", Lists.newArrayList()), + return Lists.newArrayList(DataUtils.mockMetric(1L, "pv", Lists.newArrayList()), DataUtils.mockMetric(2L, "uv", Lists.newArrayList())); } } diff --git a/headless/server/src/test/java/com/tencent/supersonic/headless/server/calcite/HeadlessParserServiceTest.java b/headless/server/src/test/java/com/tencent/supersonic/headless/server/calcite/HeadlessParserServiceTest.java index dc46478d8..40cd3d5db 100644 --- a/headless/server/src/test/java/com/tencent/supersonic/headless/server/calcite/HeadlessParserServiceTest.java +++ b/headless/server/src/test/java/com/tencent/supersonic/headless/server/calcite/HeadlessParserServiceTest.java @@ -29,8 +29,8 @@ class HeadlessParserServiceTest { private static Map headlessSchemaMap = new HashMap<>(); - public static SqlParserResp parser( - SemanticSchema semanticSchema, MetricQueryParam metricQueryParam, boolean isAgg) { + public static SqlParserResp parser(SemanticSchema semanticSchema, + MetricQueryParam metricQueryParam, boolean isAgg) { SqlParserResp sqlParser = new SqlParserResp(); try { if (semanticSchema == null) { @@ -41,9 +41,8 @@ class HeadlessParserServiceTest { QueryStatement queryStatement = new QueryStatement(); queryStatement.setMetricQueryParam(metricQueryParam); aggBuilder.explain(queryStatement, AggOption.getAggregation(!isAgg)); - EngineType engineType = - EngineType.fromString( - semanticSchema.getSemanticModel().getDatabase().getType()); + EngineType engineType = EngineType + .fromString(semanticSchema.getSemanticModel().getDatabase().getType()); sqlParser.setSql(aggBuilder.getSql(engineType)); sqlParser.setSourceId(aggBuilder.getSourceId()); } catch (Exception e) { @@ -125,8 +124,8 @@ class HeadlessParserServiceTest { datasource.setIdentifiers(identifies); SemanticSchema semanticSchema = SemanticSchema.newBuilder("1").build(); - SemanticSchemaManager.update( - semanticSchema, SemanticSchemaManager.getDatasource(datasource)); + SemanticSchemaManager.update(semanticSchema, + SemanticSchemaManager.getDatasource(datasource)); DimensionYamlTpl dimension1 = new DimensionYamlTpl(); dimension1.setExpr("page"); @@ -135,9 +134,7 @@ class HeadlessParserServiceTest { List dimensionYamlTpls = new ArrayList<>(); dimensionYamlTpls.add(dimension1); - SemanticSchemaManager.update( - semanticSchema, - "s2_pv_uv_statis", + SemanticSchemaManager.update(semanticSchema, "s2_pv_uv_statis", SemanticSchemaManager.getDimensions(dimensionYamlTpls)); MetricYamlTpl metric1 = new MetricYamlTpl(); @@ -183,13 +180,8 @@ class HeadlessParserServiceTest { addDepartment(semanticSchema); MetricQueryParam metricCommand2 = new MetricQueryParam(); - metricCommand2.setDimensions( - new ArrayList<>( - Arrays.asList( - "sys_imp_date", - "user_name__department", - "user_name", - "user_name__page"))); + metricCommand2.setDimensions(new ArrayList<>(Arrays.asList("sys_imp_date", + "user_name__department", "user_name", "user_name__page"))); metricCommand2.setMetrics(new ArrayList<>(Arrays.asList("pv"))); metricCommand2.setWhere( "user_name = 'ab' and (sys_imp_date >= '2023-02-28' and sys_imp_date <= '2023-05-28') "); @@ -246,9 +238,8 @@ class HeadlessParserServiceTest { identifies.add(identify); datasource.setIdentifiers(identifies); - semanticSchema - .getDatasource() - .put("user_department", SemanticSchemaManager.getDatasource(datasource)); + semanticSchema.getDatasource().put("user_department", + SemanticSchemaManager.getDatasource(datasource)); DimensionYamlTpl dimension1 = new DimensionYamlTpl(); dimension1.setExpr("department"); @@ -257,8 +248,7 @@ class HeadlessParserServiceTest { List dimensionYamlTpls = new ArrayList<>(); dimensionYamlTpls.add(dimension1); - semanticSchema - .getDimension() - .put("user_department", SemanticSchemaManager.getDimensions(dimensionYamlTpls)); + semanticSchema.getDimension().put("user_department", + SemanticSchemaManager.getDimensions(dimensionYamlTpls)); } } diff --git a/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/DownloadServiceImplTest.java b/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/DownloadServiceImplTest.java index e494ca516..7d0b1161e 100644 --- a/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/DownloadServiceImplTest.java +++ b/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/DownloadServiceImplTest.java @@ -35,17 +35,15 @@ class DownloadServiceImplTest { return modelSchemaResp; } - private MetricSchemaResp mockMetric( - Long id, String bizName, String name, List drillDownloadDimensions) { + private MetricSchemaResp mockMetric(Long id, String bizName, String name, + List drillDownloadDimensions) { MetricSchemaResp metricResp = new MetricSchemaResp(); metricResp.setId(id); metricResp.setBizName(bizName); metricResp.setName(name); RelateDimension relateDimension = new RelateDimension(); - relateDimension.setDrillDownDimensions( - drillDownloadDimensions.stream() - .map(DrillDownDimension::new) - .collect(Collectors.toList())); + relateDimension.setDrillDownDimensions(drillDownloadDimensions.stream() + .map(DrillDownDimension::new).collect(Collectors.toList())); metricResp.setRelateDimension(relateDimension); return metricResp; } @@ -93,8 +91,8 @@ class DownloadServiceImplTest { return semanticQueryResp; } - private static Map createMap( - String sysImpDate, String d1, String d2, String m1) { + private static Map createMap(String sysImpDate, String d1, String d2, + String m1) { Map map = new HashMap<>(); map.put("sys_imp_date", sysImpDate); map.put("user_name", d1); diff --git a/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/MetricServiceImplTest.java b/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/MetricServiceImplTest.java index 811442935..6d8cb2bfa 100644 --- a/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/MetricServiceImplTest.java +++ b/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/MetricServiceImplTest.java @@ -63,8 +63,8 @@ public class MetricServiceImplTest { Assertions.assertEquals(expectedMetricResp, actualMetricResp); } - private MetricService mockMetricService( - MetricRepository metricRepository, ModelService modelService) { + private MetricService mockMetricService(MetricRepository metricRepository, + ModelService modelService) { AliasGenerateHelper aliasGenerateHelper = Mockito.mock(AliasGenerateHelper.class); CollectService collectService = Mockito.mock(CollectService.class); ApplicationEventPublisher eventPublisher = Mockito.mock(ApplicationEventPublisher.class); @@ -72,15 +72,8 @@ public class MetricServiceImplTest { DimensionService dimensionService = Mockito.mock(DimensionService.class); TagMetaService tagMetaService = Mockito.mock(TagMetaService.class); ChatLayerService chatLayerService = Mockito.mock(ChatLayerService.class); - return new MetricServiceImpl( - metricRepository, - modelService, - aliasGenerateHelper, - collectService, - dataSetService, - eventPublisher, - dimensionService, - tagMetaService, + return new MetricServiceImpl(metricRepository, modelService, aliasGenerateHelper, + collectService, dataSetService, eventPublisher, dimensionService, tagMetaService, chatLayerService); } @@ -99,20 +92,14 @@ public class MetricServiceImplTest { dataFormat.setNeedMultiply100(false); metricReq.setDataFormat(dataFormat); MetricDefineByMeasureParams typeParams = new MetricDefineByMeasureParams(); - typeParams.setMeasures( - Lists.newArrayList( - new MeasureParam("s2_pv", "department='hr'"), - new MeasureParam("s2_uv", "department='hr'"))); + typeParams.setMeasures(Lists.newArrayList(new MeasureParam("s2_pv", "department='hr'"), + new MeasureParam("s2_uv", "department='hr'"))); typeParams.setExpr("s2_pv/s2_uv"); metricReq.setMetricDefineByMeasureParams(typeParams); metricReq.setClassifications(Lists.newArrayList("核心指标")); - metricReq.setRelateDimension( - RelateDimension.builder() - .drillDownDimensions( - Lists.newArrayList( - new DrillDownDimension(1L), - new DrillDownDimension(1L, false))) - .build()); + metricReq.setRelateDimension(RelateDimension.builder().drillDownDimensions( + Lists.newArrayList(new DrillDownDimension(1L), new DrillDownDimension(1L, false))) + .build()); metricReq.setSensitiveLevel(SensitiveLevelEnum.LOW.getCode()); metricReq.setExt(new HashMap<>()); return metricReq; @@ -133,20 +120,14 @@ public class MetricServiceImplTest { dataFormat.setNeedMultiply100(false); metricResp.setDataFormat(dataFormat); MetricDefineByMeasureParams typeParams = new MetricDefineByMeasureParams(); - typeParams.setMeasures( - Lists.newArrayList( - new MeasureParam("s2_pv", "department='hr'"), - new MeasureParam("s2_uv", "department='hr'"))); + typeParams.setMeasures(Lists.newArrayList(new MeasureParam("s2_pv", "department='hr'"), + new MeasureParam("s2_uv", "department='hr'"))); typeParams.setExpr("s2_pv/s2_uv"); metricResp.setMetricDefineByMeasureParams(typeParams); metricResp.setClassifications("核心指标"); - metricResp.setRelateDimension( - RelateDimension.builder() - .drillDownDimensions( - Lists.newArrayList( - new DrillDownDimension(1L), - new DrillDownDimension(1L, false))) - .build()); + metricResp.setRelateDimension(RelateDimension.builder().drillDownDimensions( + Lists.newArrayList(new DrillDownDimension(1L), new DrillDownDimension(1L, false))) + .build()); metricResp.setSensitiveLevel(SensitiveLevelEnum.LOW.getCode()); metricResp.setExt(new HashMap<>()); metricResp.setTypeEnum(TypeEnums.METRIC); @@ -163,10 +144,8 @@ public class MetricServiceImplTest { metricReq.setBizName("pv"); metricReq.setMetricDefineType(MetricDefineType.MEASURE); MetricDefineByMeasureParams typeParams = new MetricDefineByMeasureParams(); - typeParams.setMeasures( - Lists.newArrayList( - new MeasureParam("s2_pv", "department='hr'"), - new MeasureParam("s2_uv", "department='hr'"))); + typeParams.setMeasures(Lists.newArrayList(new MeasureParam("s2_pv", "department='hr'"), + new MeasureParam("s2_uv", "department='hr'"))); typeParams.setExpr("s2_pv/s2_uv"); metricReq.setMetricDefineByMeasureParams(typeParams); return metricReq; diff --git a/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/ModelServiceImplTest.java b/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/ModelServiceImplTest.java index 7c6211dc2..df5442359 100644 --- a/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/ModelServiceImplTest.java +++ b/headless/server/src/test/java/com/tencent/supersonic/headless/server/service/ModelServiceImplTest.java @@ -75,15 +75,8 @@ class ModelServiceImplTest { UserService userService = Mockito.mock(UserService.class); DateInfoRepository dateInfoRepository = Mockito.mock(DateInfoRepository.class); DataSetService viewService = Mockito.mock(DataSetService.class); - return new ModelServiceImpl( - modelRepository, - databaseService, - dimensionService, - metricService, - domainService, - userService, - viewService, - dateInfoRepository); + return new ModelServiceImpl(modelRepository, databaseService, dimensionService, + metricService, domainService, userService, viewService, dateInfoRepository); } private ModelReq mockModelReq() { @@ -156,9 +149,8 @@ class ModelServiceImplTest { measures.add(measure2); modelDetail.setMeasures(measures); - modelDetail.setSqlQuery( - "SELECT imp_date_a, user_name_a, page_a, 1 as pv_a," - + " user_name as uv_a FROM s2_pv_uv_statis"); + modelDetail.setSqlQuery("SELECT imp_date_a, user_name_a, page_a, 1 as pv_a," + + " user_name as uv_a FROM s2_pv_uv_statis"); modelDetail.setQueryType("sql_query"); modelReq.setDomainId(1L); modelReq.setFilterSql("where user_name = 'tom'"); @@ -189,9 +181,8 @@ class ModelServiceImplTest { Measure measure2 = new Measure("访问人数", "uv", AggOperatorEnum.COUNT_DISTINCT.name(), 1); measures.add(measure2); modelDetail.setMeasures(measures); - modelDetail.setSqlQuery( - "SELECT imp_date, user_name, page, 1 as pv, " - + "user_name as uv FROM s2_pv_uv_statis"); + modelDetail.setSqlQuery("SELECT imp_date, user_name, page, 1 as pv, " + + "user_name as uv FROM s2_pv_uv_statis"); modelDetail.setQueryType("sql_query"); modelReq.setModelDetail(modelDetail); return modelReq; @@ -274,19 +265,14 @@ class ModelServiceImplTest { measure1.setExpr("pv_a"); measures.add(measure1); - Measure measure2 = - new Measure( - "访问人数_a", - "s2_pv_uv_statis_a_uv_a", - AggOperatorEnum.COUNT_DISTINCT.name(), - 1); + Measure measure2 = new Measure("访问人数_a", "s2_pv_uv_statis_a_uv_a", + AggOperatorEnum.COUNT_DISTINCT.name(), 1); measure2.setExpr("uv_a"); measures.add(measure2); modelDetail.setMeasures(measures); - modelDetail.setSqlQuery( - "SELECT imp_date_a, user_name_a, page_a, 1 as pv_a, " - + "user_name as uv_a FROM s2_pv_uv_statis"); + modelDetail.setSqlQuery("SELECT imp_date_a, user_name_a, page_a, 1 as pv_a, " + + "user_name as uv_a FROM s2_pv_uv_statis"); modelDetail.setQueryType("sql_query"); modelResp.setModelDetail(modelDetail); modelResp.setId(1L); diff --git a/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/AliasGenerateHelperTest.java b/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/AliasGenerateHelperTest.java index 64a66f4ce..665b43153 100644 --- a/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/AliasGenerateHelperTest.java +++ b/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/AliasGenerateHelperTest.java @@ -16,14 +16,8 @@ class AliasGenerateHelperTest { void extractJsonStringFromAiMessage2() { /** ``` { "name": "Alice", "age": 25, "city": "New York" } ``` */ - String testJson2 = - "```\n" - + "{\n" - + " \"name\": \"Alice\",\n" - + " \"age\": 25,\n" - + " \"city\": \"New York\"\n" - + "}\n" - + "```"; + String testJson2 = "```\n" + "{\n" + " \"name\": \"Alice\",\n" + " \"age\": 25,\n" + + " \"city\": \"New York\"\n" + "}\n" + "```"; AliasGenerateHelper.extractJsonStringFromAiMessage(testJson2); } @@ -37,14 +31,9 @@ class AliasGenerateHelperTest { */ String testJson3 = "I understand that you want me to generate a JSON object with two properties: " - + "`tran` and `alias`...." - + "```json\n" - + "{\n" - + " \"name\": \"Alice\",\n" - + " \"age\": 25,\n" - + " \"city\": \"New York\"\n" - + "}\n" - + "```" + + "`tran` and `alias`...." + "```json\n" + "{\n" + + " \"name\": \"Alice\",\n" + " \"age\": 25,\n" + + " \"city\": \"New York\"\n" + "}\n" + "```" + "Please let me know if there is any problem."; AliasGenerateHelper.extractJsonStringFromAiMessage(testJson3); } @@ -54,14 +43,8 @@ class AliasGenerateHelperTest { String testJson4 = "Based on the provided JSON-schema, I will construct the answer as follows:\n" - + "\n" - + "[\n" - + " \"作者名称\",\n" - + " \"作者姓名\",\n" - + " \"创作者\",\n" - + " \"作者信息\"\n" - + "]\n" - + "\n" + + "\n" + "[\n" + " \"作者名称\",\n" + " \"作者姓名\",\n" + " \"创作者\",\n" + + " \"作者信息\"\n" + "]\n" + "\n" + "This answer conforms to the format described in the JSON-schema"; AliasGenerateHelper.extractJsonStringFromAiMessage(testJson4); } diff --git a/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/DataTransformUtilsTest.java b/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/DataTransformUtilsTest.java index 96dba41f1..e90a93786 100644 --- a/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/DataTransformUtilsTest.java +++ b/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/DataTransformUtilsTest.java @@ -28,8 +28,8 @@ class DataTransformUtilsTest { Assertions.assertEquals(3, resultData.size()); } - private static Map createMap( - String sysImpDate, String d1, String d2, String m1) { + private static Map createMap(String sysImpDate, String d1, String d2, + String m1) { Map map = new HashMap<>(); map.put("sys_imp_date", sysImpDate); map.put("d1", d1); diff --git a/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/DataUtils.java b/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/DataUtils.java index c56c63bd2..c59ed42a4 100644 --- a/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/DataUtils.java +++ b/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/DataUtils.java @@ -26,8 +26,8 @@ public class DataUtils { return metricSchemaResp; } - public static MetricSchemaResp mockMetric( - Long id, String bizName, String name, List drillDownDimensions) { + public static MetricSchemaResp mockMetric(Long id, String bizName, String name, + List drillDownDimensions) { MetricSchemaResp metricSchemaResp = new MetricSchemaResp(); metricSchemaResp.setId(id); metricSchemaResp.setName(name); @@ -37,8 +37,8 @@ public class DataUtils { return metricSchemaResp; } - public static MetricSchemaResp mockMetric( - Long id, String bizName, List drillDownDimensions) { + public static MetricSchemaResp mockMetric(Long id, String bizName, + List drillDownDimensions) { return mockMetric(id, bizName, null, drillDownDimensions); } } diff --git a/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/QueryNLReqBuilderTest.java b/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/QueryNLReqBuilderTest.java index 409579961..7cec555f9 100644 --- a/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/QueryNLReqBuilderTest.java +++ b/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/QueryNLReqBuilderTest.java @@ -50,29 +50,23 @@ class QueryNLReqBuilderTest { queryStructReq.setOrders(orders); QuerySqlReq querySQLReq = queryStructReq.convert(); - Assert.assertEquals( - "SELECT department, SUM(pv) AS pv FROM 内容库 " - + "WHERE (sys_imp_date IN ('2023-08-01')) GROUP " - + "BY department ORDER BY uv LIMIT 2000", - querySQLReq.getSql()); + Assert.assertEquals("SELECT department, SUM(pv) AS pv FROM 内容库 " + + "WHERE (sys_imp_date IN ('2023-08-01')) GROUP " + + "BY department ORDER BY uv LIMIT 2000", querySQLReq.getSql()); queryStructReq.setQueryType(QueryType.DETAIL); querySQLReq = queryStructReq.convert(); - Assert.assertEquals( - "SELECT department, pv FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) " - + "ORDER BY uv LIMIT 2000", - querySQLReq.getSql()); + Assert.assertEquals("SELECT department, pv FROM 内容库 WHERE (sys_imp_date IN ('2023-08-01')) " + + "ORDER BY uv LIMIT 2000", querySQLReq.getSql()); } private void init() { MockedStatic mockContextUtils = Mockito.mockStatic(ContextUtils.class); SqlFilterUtils sqlFilterUtils = new SqlFilterUtils(); - mockContextUtils - .when(() -> ContextUtils.getBean(SqlFilterUtils.class)) + mockContextUtils.when(() -> ContextUtils.getBean(SqlFilterUtils.class)) .thenReturn(sqlFilterUtils); DateModeUtils dateModeUtils = new DateModeUtils(); - mockContextUtils - .when(() -> ContextUtils.getBean(DateModeUtils.class)) + mockContextUtils.when(() -> ContextUtils.getBean(DateModeUtils.class)) .thenReturn(dateModeUtils); dateModeUtils.setSysDateCol("sys_imp_date"); dateModeUtils.setSysDateWeekCol("sys_imp_week"); diff --git a/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/SqlVariableParseUtilsTest.java b/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/SqlVariableParseUtilsTest.java index 07ea12e26..b5ab31878 100644 --- a/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/SqlVariableParseUtilsTest.java +++ b/headless/server/src/test/java/com/tencent/supersonic/headless/server/utils/SqlVariableParseUtilsTest.java @@ -15,9 +15,8 @@ public class SqlVariableParseUtilsTest { @Test void testParseSql_defaultVariableValue() { String sql = "select * from t_$interval$ where id = $id$ and name = $name$"; - List variables = - Lists.newArrayList( - mockNumSqlVariable(), mockExprSqlVariable(), mockStrSqlVariable()); + List variables = Lists.newArrayList(mockNumSqlVariable(), + mockExprSqlVariable(), mockStrSqlVariable()); String actualSql = SqlVariableParseUtils.parse(sql, variables, Lists.newArrayList()); String expectedSql = "select * from t_d where id = 1 and name = 'tom'"; Assertions.assertEquals(expectedSql, actualSql); @@ -26,9 +25,8 @@ public class SqlVariableParseUtilsTest { @Test void testParseSql() { String sql = "select * from t_$interval$ where id = $id$ and name = $name$"; - List variables = - Lists.newArrayList( - mockNumSqlVariable(), mockExprSqlVariable(), mockStrSqlVariable()); + List variables = Lists.newArrayList(mockNumSqlVariable(), + mockExprSqlVariable(), mockStrSqlVariable()); List params = Lists.newArrayList(mockIdParam(), mockNameParam(), mockIntervalParam()); String actualSql = SqlVariableParseUtils.parse(sql, variables, params); @@ -48,8 +46,8 @@ public class SqlVariableParseUtilsTest { return mockSqlVariable("interval", VariableValueType.EXPR, "d"); } - private SqlVariable mockSqlVariable( - String name, VariableValueType variableValueType, Object value) { + private SqlVariable mockSqlVariable(String name, VariableValueType variableValueType, + Object value) { SqlVariable sqlVariable = new SqlVariable(); sqlVariable.setName(name); sqlVariable.setValueType(variableValueType); diff --git a/java-formatter.xml b/java-formatter.xml new file mode 100644 index 000000000..b5e2fc2b0 --- /dev/null +++ b/java-formatter.xml @@ -0,0 +1,337 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/launchers/chat/src/main/java/com/tencent/supersonic/ChatLauncher.java b/launchers/chat/src/main/java/com/tencent/supersonic/ChatLauncher.java index 862fd7fbd..b70509845 100644 --- a/launchers/chat/src/main/java/com/tencent/supersonic/ChatLauncher.java +++ b/launchers/chat/src/main/java/com/tencent/supersonic/ChatLauncher.java @@ -7,8 +7,7 @@ import org.springframework.boot.autoconfigure.mongo.MongoAutoConfiguration; import org.springframework.scheduling.annotation.EnableScheduling; /** Chat Launcher */ -@SpringBootApplication( - scanBasePackages = {"com.tencent.supersonic"}, +@SpringBootApplication(scanBasePackages = {"com.tencent.supersonic"}, exclude = {MongoAutoConfiguration.class, MongoDataAutoConfiguration.class}) @EnableScheduling public class ChatLauncher { diff --git a/launchers/common/src/main/java/com/tencent/supersonic/advice/ResponseAdvice.java b/launchers/common/src/main/java/com/tencent/supersonic/advice/ResponseAdvice.java index 103d50019..0b82e8e84 100644 --- a/launchers/common/src/main/java/com/tencent/supersonic/advice/ResponseAdvice.java +++ b/launchers/common/src/main/java/com/tencent/supersonic/advice/ResponseAdvice.java @@ -20,27 +20,23 @@ import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyAdvice; @RestControllerAdvice(annotations = RestController.class) public class ResponseAdvice implements ResponseBodyAdvice { - @Autowired private ObjectMapper objectMapper; + @Autowired + private ObjectMapper objectMapper; @Override - public boolean supports( - MethodParameter methodParameter, Class> aClass) { + public boolean supports(MethodParameter methodParameter, + Class> aClass) { return !methodParameter.getDeclaringClass().isAssignableFrom(BasicErrorController.class); } @SneakyThrows @Override - public Object beforeBodyWrite( - Object result, - MethodParameter methodParameter, - MediaType mediaType, - Class> aClass, - ServerHttpRequest serverHttpRequest, - ServerHttpResponse serverHttpResponse) { + public Object beforeBodyWrite(Object result, MethodParameter methodParameter, + MediaType mediaType, Class> aClass, + ServerHttpRequest serverHttpRequest, ServerHttpResponse serverHttpResponse) { // 判断当前请求是否是 Swagger 相关的请求 String path = serverHttpRequest.getURI().getPath(); - if (path.startsWith("/swagger") - || path.startsWith("/v3/api-docs") + if (path.startsWith("/swagger") || path.startsWith("/v3/api-docs") || path.startsWith("/v2/api-docs")) { return result; } diff --git a/launchers/common/src/main/java/com/tencent/supersonic/config/RestTemplateConfig.java b/launchers/common/src/main/java/com/tencent/supersonic/config/RestTemplateConfig.java index 2623e91db..b9c51b9d1 100644 --- a/launchers/common/src/main/java/com/tencent/supersonic/config/RestTemplateConfig.java +++ b/launchers/common/src/main/java/com/tencent/supersonic/config/RestTemplateConfig.java @@ -25,9 +25,8 @@ public class RestTemplateConfig { HttpClientBuilder.create().setRedirectStrategy(new LaxRedirectStrategy()).build(); httpRequestFactory.setHttpClient(httpClient); RestTemplate restTemplate = new RestTemplate(httpRequestFactory); - restTemplate - .getMessageConverters() - .set(1, new StringHttpMessageConverter(StandardCharsets.UTF_8)); + restTemplate.getMessageConverters().set(1, + new StringHttpMessageConverter(StandardCharsets.UTF_8)); return restTemplate; } } diff --git a/launchers/headless/src/main/java/com/tencent/supersonic/HeadlessLauncher.java b/launchers/headless/src/main/java/com/tencent/supersonic/HeadlessLauncher.java index f856c73cd..0df63ac6c 100644 --- a/launchers/headless/src/main/java/com/tencent/supersonic/HeadlessLauncher.java +++ b/launchers/headless/src/main/java/com/tencent/supersonic/HeadlessLauncher.java @@ -8,8 +8,7 @@ import org.springframework.boot.autoconfigure.mongo.MongoAutoConfiguration; /** Headless Launcher */ @Slf4j -@SpringBootApplication( - scanBasePackages = {"com.tencent.supersonic"}, +@SpringBootApplication(scanBasePackages = {"com.tencent.supersonic"}, exclude = {MongoAutoConfiguration.class, MongoDataAutoConfiguration.class}) public class HeadlessLauncher { diff --git a/launchers/headless/src/main/java/com/tencent/supersonic/db/MybatisConfig.java b/launchers/headless/src/main/java/com/tencent/supersonic/db/MybatisConfig.java index d34cf9b34..c627ef1ba 100644 --- a/launchers/headless/src/main/java/com/tencent/supersonic/db/MybatisConfig.java +++ b/launchers/headless/src/main/java/com/tencent/supersonic/db/MybatisConfig.java @@ -19,8 +19,8 @@ public class MybatisConfig { private static final String MAPPER_LOCATION = "classpath*:mapper/**/*.xml"; @Bean - public SqlSessionFactory sqlSessionFactory( - DataSource dataSource, PageInterceptor pageInterceptor) throws Exception { + public SqlSessionFactory sqlSessionFactory(DataSource dataSource, + PageInterceptor pageInterceptor) throws Exception { SqlSessionFactoryBean bean = new SqlSessionFactoryBean(); org.apache.ibatis.session.Configuration configuration = new org.apache.ibatis.session.Configuration(); diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/StandaloneLauncher.java b/launchers/standalone/src/main/java/com/tencent/supersonic/StandaloneLauncher.java index ce57affe6..a6de486bd 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/StandaloneLauncher.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/StandaloneLauncher.java @@ -8,8 +8,7 @@ import org.springframework.scheduling.annotation.EnableAsync; import org.springframework.scheduling.annotation.EnableScheduling; import springfox.documentation.swagger2.annotations.EnableSwagger2; -@SpringBootApplication( - scanBasePackages = {"com.tencent.supersonic", "dev.langchain4j"}, +@SpringBootApplication(scanBasePackages = {"com.tencent.supersonic", "dev.langchain4j"}, exclude = {MongoAutoConfiguration.class, MongoDataAutoConfiguration.class}) @EnableScheduling @EnableAsync diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/config/SwaggerConfiguration.java b/launchers/standalone/src/main/java/com/tencent/supersonic/config/SwaggerConfiguration.java index 901846c78..7da17c1eb 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/config/SwaggerConfiguration.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/config/SwaggerConfiguration.java @@ -54,34 +54,25 @@ public class SwaggerConfiguration { @Value("${swagger.version}") private String version; - @Autowired private AuthenticationConfig authenticationConfig; + @Autowired + private AuthenticationConfig authenticationConfig; @Bean public Docket createRestApi() { - return new Docket(DocumentationType.OAS_30) - .apiInfo(apiInfo()) - .select() - .apis(RequestHandlerSelectors.basePackage(basePackage)) - .paths(PathSelectors.any()) - .build() - .securitySchemes(Lists.newArrayList(apiKey())); + return new Docket(DocumentationType.OAS_30).apiInfo(apiInfo()).select() + .apis(RequestHandlerSelectors.basePackage(basePackage)).paths(PathSelectors.any()) + .build().securitySchemes(Lists.newArrayList(apiKey())); } private ApiKey apiKey() { - return new ApiKey( - authenticationConfig.getTokenHttpHeaderKey(), - authenticationConfig.getTokenHttpHeaderKey(), - "header"); + return new ApiKey(authenticationConfig.getTokenHttpHeaderKey(), + authenticationConfig.getTokenHttpHeaderKey(), "header"); } private ApiInfo apiInfo() { - return new ApiInfoBuilder() - .title(title) - .description(description) - .termsOfServiceUrl(url) - .contact(new Contact(contactName, contactUrl, contactEmail)) - .version(version) + return new ApiInfoBuilder().title(title).description(description).termsOfServiceUrl(url) + .contact(new Contact(contactName, contactUrl, contactEmail)).version(version) .build(); } } diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/CspiderDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/CspiderDemo.java index d9c7fb89a..d5e17dc1b 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/CspiderDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/CspiderDemo.java @@ -150,8 +150,8 @@ public class CspiderDemo extends S2BaseDemo { modelDetail.setMeasures(Collections.emptyList()); modelDetail.setQueryType("sql_query"); - modelDetail.setSqlQuery( - "SELECT f_id, artist_name, file_size, duration, formats FROM files"); + modelDetail + .setSqlQuery("SELECT f_id, artist_name, file_size, duration, formats FROM files"); modelReq.setModelDetail(modelDetail); return modelService.createModel(modelReq, user); } @@ -188,9 +188,8 @@ public class CspiderDemo extends S2BaseDemo { modelDetail.setMeasures(measures); modelDetail.setQueryType("sql_query"); - modelDetail.setSqlQuery( - "SELECT imp_date, song_name, artist_name, country, f_id, g_name, " - + " rating, languages, releasedate, resolution FROM song"); + modelDetail.setSqlQuery("SELECT imp_date, song_name, artist_name, country, f_id, g_name, " + + " rating, languages, releasedate, resolution FROM song"); modelReq.setModelDetail(modelDetail); return modelService.createModel(modelReq, user); } @@ -228,8 +227,8 @@ public class CspiderDemo extends S2BaseDemo { dataSetService.save(dataSetReq, User.getFakeUser()); } - public void addModelRela_1( - DomainResp s2Domain, ModelResp genreModelResp, ModelResp artistModelResp) { + public void addModelRela_1(DomainResp s2Domain, ModelResp genreModelResp, + ModelResp artistModelResp) { List joinConditions = Lists.newArrayList(); joinConditions.add(new JoinCondition("g_name", "g_name", FilterOperatorEnum.EQUALS)); ModelRela modelRelaReq = new ModelRela(); @@ -241,11 +240,11 @@ public class CspiderDemo extends S2BaseDemo { modelRelaService.save(modelRelaReq, user); } - public void addModelRela_2( - DomainResp s2Domain, ModelResp filesModelResp, ModelResp artistModelResp) { + public void addModelRela_2(DomainResp s2Domain, ModelResp filesModelResp, + ModelResp artistModelResp) { List joinConditions = Lists.newArrayList(); - joinConditions.add( - new JoinCondition("artist_name", "artist_name", FilterOperatorEnum.EQUALS)); + joinConditions + .add(new JoinCondition("artist_name", "artist_name", FilterOperatorEnum.EQUALS)); ModelRela modelRelaReq = new ModelRela(); modelRelaReq.setDomainId(s2Domain.getId()); modelRelaReq.setFromModelId(filesModelResp.getId()); @@ -255,11 +254,11 @@ public class CspiderDemo extends S2BaseDemo { modelRelaService.save(modelRelaReq, user); } - public void addModelRela_3( - DomainResp s2Domain, ModelResp songModelResp, ModelResp artistModelResp) { + public void addModelRela_3(DomainResp s2Domain, ModelResp songModelResp, + ModelResp artistModelResp) { List joinConditions = Lists.newArrayList(); - joinConditions.add( - new JoinCondition("artist_name", "artist_name", FilterOperatorEnum.EQUALS)); + joinConditions + .add(new JoinCondition("artist_name", "artist_name", FilterOperatorEnum.EQUALS)); ModelRela modelRelaReq = new ModelRela(); modelRelaReq.setDomainId(s2Domain.getId()); modelRelaReq.setFromModelId(songModelResp.getId()); @@ -269,8 +268,8 @@ public class CspiderDemo extends S2BaseDemo { modelRelaService.save(modelRelaReq, user); } - public void addModelRela_4( - DomainResp s2Domain, ModelResp songModelResp, ModelResp genreModelResp) { + public void addModelRela_4(DomainResp s2Domain, ModelResp songModelResp, + ModelResp genreModelResp) { List joinConditions = Lists.newArrayList(); joinConditions.add(new JoinCondition("g_name", "g_name", FilterOperatorEnum.EQUALS)); ModelRela modelRelaReq = new ModelRela(); @@ -282,8 +281,8 @@ public class CspiderDemo extends S2BaseDemo { modelRelaService.save(modelRelaReq, user); } - public void addModelRela_5( - DomainResp s2Domain, ModelResp songModelResp, ModelResp filesModelResp) { + public void addModelRela_5(DomainResp s2Domain, ModelResp songModelResp, + ModelResp filesModelResp) { List joinConditions = Lists.newArrayList(); joinConditions.add(new JoinCondition("f_id", "f_id", FilterOperatorEnum.EQUALS)); ModelRela modelRelaReq = new ModelRela(); diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/DuSQLDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/DuSQLDemo.java index 126c83349..eea1ddfc6 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/DuSQLDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/DuSQLDemo.java @@ -109,9 +109,8 @@ public class DuSQLDemo extends S2BaseDemo { modelDetail.setMeasures(measures); modelDetail.setQueryType("sql_query"); - modelDetail.setSqlQuery( - "SELECT imp_date,company_id,company_name,headquarter_address," - + "company_established_time,founder,ceo,annual_turnover,employee_count FROM company"); + modelDetail.setSqlQuery("SELECT imp_date,company_id,company_name,headquarter_address," + + "company_established_time,founder,ceo,annual_turnover,employee_count FROM company"); modelReq.setModelDetail(modelDetail); modelService.createModel(modelReq, user); } @@ -138,8 +137,8 @@ public class DuSQLDemo extends S2BaseDemo { dimensions.add(new Dim("品牌名称", "brand_name", DimensionType.categorical.name(), 1)); dimensions.add( new Dim("品牌成立时间", "brand_established_time", DimensionType.categorical.name(), 1)); - dimensions.add( - new Dim("法定代表人", "legal_representative", DimensionType.categorical.name(), 1)); + dimensions + .add(new Dim("法定代表人", "legal_representative", DimensionType.categorical.name(), 1)); modelDetail.setDimensions(dimensions); List identifiers = new ArrayList<>(); @@ -152,9 +151,8 @@ public class DuSQLDemo extends S2BaseDemo { modelDetail.setMeasures(measures); modelDetail.setQueryType("sql_query"); - modelDetail.setSqlQuery( - "SELECT imp_date,brand_id,brand_name,brand_established_time," - + "company_id,legal_representative,registered_capital FROM brand"); + modelDetail.setSqlQuery("SELECT imp_date,brand_id,brand_name,brand_established_time," + + "company_id,legal_representative,registered_capital FROM brand"); modelReq.setModelDetail(modelDetail); modelService.createModel(modelReq, user); } @@ -193,9 +191,8 @@ public class DuSQLDemo extends S2BaseDemo { modelDetail.setMeasures(measures); modelDetail.setQueryType("sql_query"); - modelDetail.setSqlQuery( - "SELECT imp_date,company_id,brand_id,revenue_proportion," - + "profit_proportion,expenditure_proportion FROM company_revenue"); + modelDetail.setSqlQuery("SELECT imp_date,company_id,brand_id,revenue_proportion," + + "profit_proportion,expenditure_proportion FROM company_revenue"); modelReq.setModelDetail(modelDetail); modelService.createModel(modelReq, user); MetricResp metricResp = metricService.getMetric(13L, user); @@ -235,17 +232,15 @@ public class DuSQLDemo extends S2BaseDemo { List measures = new ArrayList<>(); measures.add(new Measure("营收", "revenue", AggOperatorEnum.SUM.name(), 1)); measures.add(new Measure("利润", "profit", AggOperatorEnum.SUM.name(), 1)); - measures.add( - new Measure( - "营收同比增长", "revenue_growth_year_on_year", AggOperatorEnum.SUM.name(), 1)); + measures.add(new Measure("营收同比增长", "revenue_growth_year_on_year", + AggOperatorEnum.SUM.name(), 1)); measures.add( new Measure("利润同比增长", "profit_growth_year_on_year", AggOperatorEnum.SUM.name(), 1)); modelDetail.setMeasures(measures); modelDetail.setQueryType("sql_query"); - modelDetail.setSqlQuery( - "SELECT imp_date,year_time,brand_id,revenue,profit," - + "revenue_growth_year_on_year,profit_growth_year_on_year FROM company_brand_revenue"); + modelDetail.setSqlQuery("SELECT imp_date,year_time,brand_id,revenue,profit," + + "revenue_growth_year_on_year,profit_growth_year_on_year FROM company_brand_revenue"); modelReq.setModelDetail(modelDetail); modelService.createModel(modelReq, user); } @@ -257,20 +252,15 @@ public class DuSQLDemo extends S2BaseDemo { dataSetReq.setDomainId(4L); dataSetReq.setDescription("DuSQL互联网企业数据源相关的指标和维度等"); dataSetReq.setAdmins(Lists.newArrayList("admin")); - List viewModelConfigs = - Lists.newArrayList( - new DataSetModelConfig( - 9L, - Lists.newArrayList(16L, 17L, 18L, 19L, 20L), - Lists.newArrayList(10L, 11L)), - new DataSetModelConfig( - 10L, Lists.newArrayList(21L, 22L, 23L), Lists.newArrayList(12L)), - new DataSetModelConfig( - 11L, Lists.newArrayList(), Lists.newArrayList(13L, 14L, 15L)), - new DataSetModelConfig( - 12L, - Lists.newArrayList(24L), - Lists.newArrayList(16L, 17L, 18L, 19L))); + List viewModelConfigs = Lists.newArrayList( + new DataSetModelConfig(9L, Lists.newArrayList(16L, 17L, 18L, 19L, 20L), + Lists.newArrayList(10L, 11L)), + new DataSetModelConfig(10L, Lists.newArrayList(21L, 22L, 23L), + Lists.newArrayList(12L)), + new DataSetModelConfig(11L, Lists.newArrayList(), + Lists.newArrayList(13L, 14L, 15L)), + new DataSetModelConfig(12L, Lists.newArrayList(24L), + Lists.newArrayList(16L, 17L, 18L, 19L))); DataSetDetail dsDetail = new DataSetDetail(); dsDetail.setDataSetModelConfigs(viewModelConfigs); @@ -289,8 +279,8 @@ public class DuSQLDemo extends S2BaseDemo { public void addModelRela_1() { List joinConditions = Lists.newArrayList(); - joinConditions.add( - new JoinCondition("company_id", "company_id", FilterOperatorEnum.EQUALS)); + joinConditions + .add(new JoinCondition("company_id", "company_id", FilterOperatorEnum.EQUALS)); ModelRela modelRelaReq = new ModelRela(); modelRelaReq.setDomainId(4L); modelRelaReq.setFromModelId(9L); @@ -302,8 +292,8 @@ public class DuSQLDemo extends S2BaseDemo { public void addModelRela_2() { List joinConditions = Lists.newArrayList(); - joinConditions.add( - new JoinCondition("company_id", "company_id", FilterOperatorEnum.EQUALS)); + joinConditions + .add(new JoinCondition("company_id", "company_id", FilterOperatorEnum.EQUALS)); ModelRela modelRelaReq = new ModelRela(); modelRelaReq.setDomainId(4L); modelRelaReq.setFromModelId(9L); diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2ArtistDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2ArtistDemo.java index c381795ea..16fef0a2d 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2ArtistDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2ArtistDemo.java @@ -87,9 +87,8 @@ public class S2ArtistDemo extends S2BaseDemo { return domainService.createDomain(domainReq, user); } - public ModelResp addModel( - DomainResp singerDomain, DatabaseResp s2Database, TagObjectResp singerTagObject) - throws Exception { + public ModelResp addModel(DomainResp singerDomain, DatabaseResp s2Database, + TagObjectResp singerTagObject) throws Exception { ModelReq modelReq = new ModelReq(); modelReq.setName("歌手库"); modelReq.setBizName("singer"); @@ -119,25 +118,20 @@ public class S2ArtistDemo extends S2BaseDemo { Measure measure3 = new Measure("收藏量", "favor_cnt", "sum", 1); modelDetail.setMeasures(Lists.newArrayList(measure1, measure2, measure3)); modelDetail.setQueryType("sql_query"); - modelDetail.setSqlQuery( - "select singer_name, act_area, song_name, genre, " - + "js_play_cnt, down_cnt, favor_cnt from singer"); + modelDetail.setSqlQuery("select singer_name, act_area, song_name, genre, " + + "js_play_cnt, down_cnt, favor_cnt from singer"); modelReq.setModelDetail(modelDetail); return modelService.createModel(modelReq, user); } private void addTags(ModelResp model) { - addTag( - dimensionService.getDimension("act_area", model.getId()).getId(), + addTag(dimensionService.getDimension("act_area", model.getId()).getId(), TagDefineType.DIMENSION); - addTag( - dimensionService.getDimension("song_name", model.getId()).getId(), + addTag(dimensionService.getDimension("song_name", model.getId()).getId(), TagDefineType.DIMENSION); - addTag( - dimensionService.getDimension("genre", model.getId()).getId(), + addTag(dimensionService.getDimension("genre", model.getId()).getId(), TagDefineType.DIMENSION); - addTag( - dimensionService.getDimension("singer_name", model.getId()).getId(), + addTag(dimensionService.getDimension("singer_name", model.getId()).getId(), TagDefineType.DIMENSION); addTag(metricService.getMetric(model.getId(), "js_play_cnt").getId(), TagDefineType.METRIC); } diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2BaseDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2BaseDemo.java index e8b833ef9..82d143458 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2BaseDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2BaseDemo.java @@ -49,25 +49,44 @@ public abstract class S2BaseDemo implements CommandLineRunner { protected DatabaseResp demoDatabaseResp; protected User user = User.getFakeUser(); - @Autowired protected DatabaseService databaseService; - @Autowired protected DomainService domainService; - @Autowired protected ModelService modelService; - @Autowired protected ModelRelaService modelRelaService; - @Autowired protected DimensionService dimensionService; - @Autowired protected MetricService metricService; - @Autowired protected TagMetaService tagMetaService; - @Autowired protected AuthService authService; - @Autowired protected DataSetService dataSetService; - @Autowired protected TermService termService; - @Autowired protected PluginService pluginService; - @Autowired protected DataSourceProperties dataSourceProperties; - @Autowired protected TagObjectService tagObjectService; - @Autowired protected ChatQueryService chatQueryService; - @Autowired protected ChatManageService chatManageService; - @Autowired protected AgentService agentService; - @Autowired protected SystemConfigService sysParameterService; - @Autowired protected CanvasService canvasService; - @Autowired protected DictWordService dictWordService; + @Autowired + protected DatabaseService databaseService; + @Autowired + protected DomainService domainService; + @Autowired + protected ModelService modelService; + @Autowired + protected ModelRelaService modelRelaService; + @Autowired + protected DimensionService dimensionService; + @Autowired + protected MetricService metricService; + @Autowired + protected TagMetaService tagMetaService; + @Autowired + protected AuthService authService; + @Autowired + protected DataSetService dataSetService; + @Autowired + protected TermService termService; + @Autowired + protected PluginService pluginService; + @Autowired + protected DataSourceProperties dataSourceProperties; + @Autowired + protected TagObjectService tagObjectService; + @Autowired + protected ChatQueryService chatQueryService; + @Autowired + protected ChatManageService chatManageService; + @Autowired + protected AgentService agentService; + @Autowired + protected SystemConfigService sysParameterService; + @Autowired + protected CanvasService canvasService; + @Autowired + protected DictWordService dictWordService; @Value("${s2.demo.names:S2VisitsDemo}") protected List demoList; @@ -106,8 +125,8 @@ public abstract class S2BaseDemo implements CommandLineRunner { } databaseReq.setUrl(url); databaseReq.setUsername(dataSourceProperties.getUsername()); - databaseReq.setPassword( - AESEncryptionUtil.aesEncryptECB(dataSourceProperties.getPassword())); + databaseReq + .setPassword(AESEncryptionUtil.aesEncryptECB(dataSourceProperties.getPassword())); return databaseService.createOrUpdateDatabase(databaseReq, user); } @@ -125,15 +144,11 @@ public abstract class S2BaseDemo implements CommandLineRunner { dataSetModelConfig.setId(modelResp.getId()); MetaFilter metaFilter = new MetaFilter(); metaFilter.setModelIds(Lists.newArrayList(modelResp.getId())); - List metrics = - metricService.getMetrics(metaFilter).stream() - .map(MetricResp::getId) - .collect(Collectors.toList()); + List metrics = metricService.getMetrics(metaFilter).stream() + .map(MetricResp::getId).collect(Collectors.toList()); dataSetModelConfig.setMetrics(metrics); - List dimensions = - dimensionService.getDimensions(metaFilter).stream() - .map(DimensionResp::getId) - .collect(Collectors.toList()); + List dimensions = dimensionService.getDimensions(metaFilter).stream() + .map(DimensionResp::getId).collect(Collectors.toList()); dataSetModelConfig.setMetrics(metrics); dataSetModelConfig.setDimensions(dimensions); dataSetModelConfigs.add(dataSetModelConfig); diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java index 23ad37215..9994b3cc8 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/demo/S2VisitsDemo.java @@ -150,13 +150,8 @@ public class S2VisitsDemo extends S2BaseDemo { agent.setDescription("帮助您用自然语言查询指标,支持时间限定、条件筛选、下钻维度以及聚合统计"); agent.setStatus(1); agent.setEnableSearch(1); - agent.setExamples( - Lists.newArrayList( - "超音数访问次数", - "近15天超音数访问次数汇总", - "按部门统计超音数的访问人数", - "对比alice和lucy的停留时长", - "超音数访问次数最高的部门")); + agent.setExamples(Lists.newArrayList("超音数访问次数", "近15天超音数访问次数汇总", "按部门统计超音数的访问人数", + "对比alice和lucy的停留时长", "超音数访问次数最高的部门")); AgentConfig agentConfig = new AgentConfig(); RuleParserTool ruleQueryTool = new RuleParserTool(); ruleQueryTool.setType(AgentToolType.NL2SQL_RULE); @@ -189,9 +184,8 @@ public class S2VisitsDemo extends S2BaseDemo { return domainService.createDomain(domainReq, user); } - public ModelResp addModel_1( - DomainResp s2Domain, DatabaseResp s2Database, TagObjectResp s2TagObject) - throws Exception { + public ModelResp addModel_1(DomainResp s2Domain, DatabaseResp s2Database, + TagObjectResp s2TagObject) throws Exception { ModelReq modelReq = new ModelReq(); modelReq.setName("用户部门"); modelReq.setBizName("user_department"); @@ -259,9 +253,8 @@ public class S2VisitsDemo extends S2BaseDemo { fields.add(Field.builder().fieldName("pv").dataType("Long").build()); fields.add(Field.builder().fieldName("user_id").dataType("Varchar").build()); modelDetail.setFields(fields); - modelDetail.setSqlQuery( - "SELECT imp_date, user_name, page, 1 as pv, " - + "user_name as user_id FROM s2_pv_uv_statis"); + modelDetail.setSqlQuery("SELECT imp_date, user_name, page, 1 as pv, " + + "user_name as user_id FROM s2_pv_uv_statis"); modelDetail.setQueryType("sql_query"); modelReq.setModelDetail(modelDetail); return modelService.createModel(modelReq, user); @@ -302,15 +295,15 @@ public class S2VisitsDemo extends S2BaseDemo { fields.add(Field.builder().fieldName("page").dataType("Varchar").build()); fields.add(Field.builder().fieldName("stay_hours").dataType("Double").build()); modelDetail.setFields(fields); - modelDetail.setSqlQuery( - "select imp_date,user_name,stay_hours,page from s2_stay_time_statis"); + modelDetail + .setSqlQuery("select imp_date,user_name,stay_hours,page from s2_stay_time_statis"); modelDetail.setQueryType("sql_query"); modelReq.setModelDetail(modelDetail); return modelService.createModel(modelReq, user); } - public void addModelRela_1( - DomainResp s2Domain, ModelResp userDepartmentModel, ModelResp pvUvModel) { + public void addModelRela_1(DomainResp s2Domain, ModelResp userDepartmentModel, + ModelResp pvUvModel) { List joinConditions = Lists.newArrayList(); joinConditions.add(new JoinCondition("user_name", "user_name", FilterOperatorEnum.EQUALS)); ModelRela modelRelaReq = new ModelRela(); @@ -322,8 +315,8 @@ public class S2VisitsDemo extends S2BaseDemo { modelRelaService.save(modelRelaReq, user); } - public void addModelRela_2( - DomainResp s2Domain, ModelResp userDepartmentModel, ModelResp stayTimeModel) { + public void addModelRela_2(DomainResp s2Domain, ModelResp userDepartmentModel, + ModelResp stayTimeModel) { List joinConditions = Lists.newArrayList(); joinConditions.add(new JoinCondition("user_name", "user_name", FilterOperatorEnum.EQUALS)); ModelRela modelRelaReq = new ModelRela(); @@ -336,8 +329,7 @@ public class S2VisitsDemo extends S2BaseDemo { } private void addTags(ModelResp model) { - addTag( - dimensionService.getDimension("department", model.getId()).getId(), + addTag(dimensionService.getDimension("department", model.getId()).getId(), TagDefineType.DIMENSION); } @@ -358,9 +350,8 @@ public class S2VisitsDemo extends S2BaseDemo { dimensionService.updateDimension(dimensionReq, user); } - public void updateMetric( - ModelResp stayTimeModel, DimensionResp departmentDimension, DimensionResp userDimension) - throws Exception { + public void updateMetric(ModelResp stayTimeModel, DimensionResp departmentDimension, + DimensionResp userDimension) throws Exception { MetricResp stayHoursMetric = metricService.getMetric(stayTimeModel.getId(), "stay_hours"); MetricReq metricReq = new MetricReq(); metricReq.setModelId(stayTimeModel.getId()); @@ -373,25 +364,19 @@ public class S2VisitsDemo extends S2BaseDemo { MetricDefineByMeasureParams metricTypeParams = new MetricDefineByMeasureParams(); metricTypeParams.setExpr("s2_stay_time_statis_stay_hours"); List measures = new ArrayList<>(); - MeasureParam measure = - new MeasureParam( - "s2_stay_time_statis_stay_hours", "", AggOperatorEnum.SUM.getOperator()); + MeasureParam measure = new MeasureParam("s2_stay_time_statis_stay_hours", "", + AggOperatorEnum.SUM.getOperator()); measures.add(measure); metricTypeParams.setMeasures(measures); metricReq.setMetricDefineByMeasureParams(metricTypeParams); metricReq.setMetricDefineType(MetricDefineType.MEASURE); - metricReq.setRelateDimension( - getRelateDimension( - Lists.newArrayList(departmentDimension.getId(), userDimension.getId()))); + metricReq.setRelateDimension(getRelateDimension( + Lists.newArrayList(departmentDimension.getId(), userDimension.getId()))); metricService.updateMetric(metricReq, user); } - public void updateMetric_pv( - ModelResp pvUvModel, - DimensionResp departmentDimension, - DimensionResp userDimension, - MetricResp metricPv) - throws Exception { + public void updateMetric_pv(ModelResp pvUvModel, DimensionResp departmentDimension, + DimensionResp userDimension, MetricResp metricPv) throws Exception { MetricReq metricReq = new MetricReq(); metricReq.setModelId(pvUvModel.getId()); metricReq.setId(metricPv.getId()); @@ -407,9 +392,8 @@ public class S2VisitsDemo extends S2BaseDemo { metricTypeParams.setMeasures(measures); metricReq.setMetricDefineByMeasureParams(metricTypeParams); metricReq.setMetricDefineType(MetricDefineType.MEASURE); - metricReq.setRelateDimension( - getRelateDimension( - Lists.newArrayList(departmentDimension.getId(), userDimension.getId()))); + metricReq.setRelateDimension(getRelateDimension( + Lists.newArrayList(departmentDimension.getId(), userDimension.getId()))); metricService.updateMetric(metricReq, user); } @@ -434,12 +418,8 @@ public class S2VisitsDemo extends S2BaseDemo { return metricService.createMetric(metricReq, user); } - public MetricResp addMetric_pv_avg( - MetricResp metricPv, - MetricResp metricUv, - DimensionResp departmentDimension, - ModelResp pvModel) - throws Exception { + public MetricResp addMetric_pv_avg(MetricResp metricPv, MetricResp metricUv, + DimensionResp departmentDimension, ModelResp pvModel) throws Exception { MetricReq metricReq = new MetricReq(); metricReq.setModelId(pvModel.getId()); metricReq.setName("人均访问次数"); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/BaseApplication.java b/launchers/standalone/src/test/java/com/tencent/supersonic/BaseApplication.java index 3f4436d2d..cc53a796a 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/BaseApplication.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/BaseApplication.java @@ -3,4 +3,5 @@ package com.tencent.supersonic; import org.springframework.boot.test.context.SpringBootTest; @SpringBootTest(classes = {StandaloneLauncher.class}) -public class BaseApplication {} +public class BaseApplication { +} diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java index 6011e83a4..fe70c5e37 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/BaseTest.java @@ -29,23 +29,19 @@ public class BaseTest extends BaseApplication { protected final String endDay = LocalDate.now().toString(); protected final DatePeriodEnum period = DatePeriodEnum.DAY; - @Autowired protected ChatQueryService chatQueryService; - @Autowired protected AgentService agentService; + @Autowired + protected ChatQueryService chatQueryService; + @Autowired + protected AgentService agentService; protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId) throws Exception { ParseResp parseResp = submitParse(queryText, agentId, chatId); SemanticParseInfo semanticParseInfo = parseResp.getSelectedParses().get(0); - ChatExecuteReq request = - ChatExecuteReq.builder() - .queryText(parseResp.getQueryText()) - .user(DataUtils.getUser()) - .parseId(semanticParseInfo.getId()) - .queryId(parseResp.getQueryId()) - .chatId(chatId) - .saveAnswer(true) - .build(); + ChatExecuteReq request = ChatExecuteReq.builder().queryText(parseResp.getQueryText()) + .user(DataUtils.getUser()).parseId(semanticParseInfo.getId()) + .queryId(parseResp.getQueryId()).chatId(chatId).saveAnswer(true).build(); QueryResult queryResult = chatQueryService.performExecution(request); queryResult.setChatContext(semanticParseInfo); return queryResult; @@ -56,16 +52,9 @@ public class BaseTest extends BaseApplication { ParseResp parseResp = submitParse(queryText, agentId, chatId); SemanticParseInfo parseInfo = parseResp.getSelectedParses().get(0); - ChatExecuteReq request = - ChatExecuteReq.builder() - .queryText(parseResp.getQueryText()) - .user(DataUtils.getUser()) - .parseId(parseInfo.getId()) - .agentId(agentId) - .chatId(chatId) - .queryId(parseResp.getQueryId()) - .saveAnswer(false) - .build(); + ChatExecuteReq request = ChatExecuteReq.builder().queryText(parseResp.getQueryText()) + .user(DataUtils.getUser()).parseId(parseInfo.getId()).agentId(agentId) + .chatId(chatId).queryId(parseResp.getQueryId()).saveAnswer(false).build(); QueryResult result = chatQueryService.performExecution(request); result.setChatContext(parseInfo); @@ -79,16 +68,10 @@ public class BaseTest extends BaseApplication { } protected void assertSchemaElements(Set expected, Set actual) { - Set expectedNames = - expected.stream() - .map(s -> s.getName()) - .filter(s -> s != null) - .collect(Collectors.toSet()); - Set actualNames = - actual.stream() - .map(s -> s.getName()) - .filter(s -> s != null) - .collect(Collectors.toSet()); + Set expectedNames = expected.stream().map(s -> s.getName()).filter(s -> s != null) + .collect(Collectors.toSet()); + Set actualNames = actual.stream().map(s -> s.getName()).filter(s -> s != null) + .collect(Collectors.toSet()); assertEquals(expectedNames, actualNames); } @@ -104,8 +87,8 @@ public class BaseTest extends BaseApplication { assertSchemaElements(expectedParseInfo.getMetrics(), actualParseInfo.getMetrics()); assertSchemaElements(expectedParseInfo.getDimensions(), actualParseInfo.getDimensions()); - assertEquals( - expectedParseInfo.getDimensionFilters(), actualParseInfo.getDimensionFilters()); + assertEquals(expectedParseInfo.getDimensionFilters(), + actualParseInfo.getDimensionFilters()); assertEquals(expectedParseInfo.getMetricFilters(), actualParseInfo.getMetricFilters()); assertEquals(expectedParseInfo.getDateInfo(), actualParseInfo.getDateInfo()); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/DetailTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/DetailTest.java index 4146e6cc4..591074c66 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/DetailTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/DetailTest.java @@ -36,12 +36,9 @@ public class DetailTest extends BaseTest { DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS, "周杰伦", "歌手名", 8L); expectedParseInfo.getDimensionFilters().add(dimensionFilter); - expectedParseInfo - .getDimensions() - .addAll( - Lists.newArrayList( - SchemaElement.builder().name("流派").build(), - SchemaElement.builder().name("代表作").build())); + expectedParseInfo.getDimensions() + .addAll(Lists.newArrayList(SchemaElement.builder().name("流派").build(), + SchemaElement.builder().name("代表作").build())); assertQueryResult(expectedResult, actualResult); } @@ -63,14 +60,11 @@ public class DetailTest extends BaseTest { expectedParseInfo.getDimensionFilters().add(dimensionFilter); expectedParseInfo.getMetrics().add(SchemaElement.builder().name("播放量").build()); - expectedParseInfo - .getDimensions() - .addAll( - Lists.newArrayList( - SchemaElement.builder().name("歌手名").build(), - SchemaElement.builder().name("活跃区域").build(), - SchemaElement.builder().name("流派").build(), - SchemaElement.builder().name("代表作").build())); + expectedParseInfo.getDimensions() + .addAll(Lists.newArrayList(SchemaElement.builder().name("歌手名").build(), + SchemaElement.builder().name("活跃区域").build(), + SchemaElement.builder().name("流派").build(), + SchemaElement.builder().name("代表作").build())); assertQueryResult(expectedResult, actualResult); } @@ -92,14 +86,11 @@ public class DetailTest extends BaseTest { expectedParseInfo.getDimensionFilters().add(dimensionFilter); expectedParseInfo.getMetrics().add(SchemaElement.builder().name("播放量").build()); - expectedParseInfo - .getDimensions() - .addAll( - Lists.newArrayList( - SchemaElement.builder().name("歌手名").build(), - SchemaElement.builder().name("活跃区域").build(), - SchemaElement.builder().name("流派").build(), - SchemaElement.builder().name("代表作").build())); + expectedParseInfo.getDimensions() + .addAll(Lists.newArrayList(SchemaElement.builder().name("歌手名").build(), + SchemaElement.builder().name("活跃区域").build(), + SchemaElement.builder().name("流派").build(), + SchemaElement.builder().name("代表作").build())); assertQueryResult(expectedResult, actualResult); } diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java index 59a9fce6d..036945278 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MetricTest.java @@ -38,11 +38,8 @@ public class MetricTest extends BaseTest { expectedParseInfo.setAggType(NONE); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); - expectedParseInfo - .getDimensionFilters() - .add( - DataUtils.getFilter( - "user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L)); + expectedParseInfo.getDimensionFilters().add( + DataUtils.getFilter("user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L)); expectedParseInfo.setDateInfo( DataUtils.getDateConf(DateConf.DateMode.BETWEEN, unit, period, startDay, endDay)); @@ -65,9 +62,8 @@ public class MetricTest extends BaseTest { expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门")); - expectedParseInfo.setDateInfo( - DataUtils.getDateConf( - DateConf.DateMode.BETWEEN, 7, DatePeriodEnum.DAY, startDay, endDay)); + expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.BETWEEN, 7, + DatePeriodEnum.DAY, startDay, endDay)); expectedParseInfo.setQueryType(QueryType.AGGREGATE); assertQueryResult(expectedResult, actualResult); @@ -158,11 +154,8 @@ public class MetricTest extends BaseTest { expectedParseInfo.setAggType(NONE); expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); - expectedParseInfo - .getDimensionFilters() - .add( - DataUtils.getFilter( - "user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L)); + expectedParseInfo.getDimensionFilters().add( + DataUtils.getFilter("user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L)); expectedParseInfo.setDateInfo( DataUtils.getDateConf(DateConf.DateMode.BETWEEN, 1, period, startDay, startDay)); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MultiTurnsTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MultiTurnsTest.java index 9b86195bc..35a27bb41 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MultiTurnsTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/chat/MultiTurnsTest.java @@ -17,9 +17,8 @@ public class MultiTurnsTest extends BaseTest { @Test @Order(1) public void queryTest_01() throws Exception { - QueryResult actualResult = - submitMultiTurnChat( - "alice的访问次数", DataUtils.metricAgentId, DataUtils.MULTI_TURNS_CHAT_ID); + QueryResult actualResult = submitMultiTurnChat("alice的访问次数", DataUtils.metricAgentId, + DataUtils.MULTI_TURNS_CHAT_ID); QueryResult expectedResult = new QueryResult(); SemanticParseInfo expectedParseInfo = new SemanticParseInfo(); @@ -30,11 +29,8 @@ public class MultiTurnsTest extends BaseTest { expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数")); - expectedParseInfo - .getDimensionFilters() - .add( - DataUtils.getFilter( - "user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L)); + expectedParseInfo.getDimensionFilters().add( + DataUtils.getFilter("user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L)); expectedParseInfo.setDateInfo( DataUtils.getDateConf(DateConf.DateMode.BETWEEN, unit, period, startDay, endDay)); @@ -46,9 +42,8 @@ public class MultiTurnsTest extends BaseTest { @Test @Order(2) public void queryTest_02() throws Exception { - QueryResult actualResult = - submitMultiTurnChat( - "停留时长呢", DataUtils.metricAgentId, DataUtils.MULTI_TURNS_CHAT_ID); + QueryResult actualResult = submitMultiTurnChat("停留时长呢", DataUtils.metricAgentId, + DataUtils.MULTI_TURNS_CHAT_ID); QueryResult expectedResult = new QueryResult(); SemanticParseInfo expectedParseInfo = new SemanticParseInfo(); @@ -59,11 +54,8 @@ public class MultiTurnsTest extends BaseTest { expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("停留时长")); - expectedParseInfo - .getDimensionFilters() - .add( - DataUtils.getFilter( - "user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L)); + expectedParseInfo.getDimensionFilters().add( + DataUtils.getFilter("user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L)); expectedParseInfo.setDateInfo( DataUtils.getDateConf(DateConf.DateMode.BETWEEN, unit, period, startDay, endDay)); @@ -75,9 +67,8 @@ public class MultiTurnsTest extends BaseTest { @Test @Order(3) public void queryTest_03() throws Exception { - QueryResult actualResult = - submitMultiTurnChat( - "lucy的如何", DataUtils.metricAgentId, DataUtils.MULTI_TURNS_CHAT_ID); + QueryResult actualResult = submitMultiTurnChat("lucy的如何", DataUtils.metricAgentId, + DataUtils.MULTI_TURNS_CHAT_ID); QueryResult expectedResult = new QueryResult(); SemanticParseInfo expectedParseInfo = new SemanticParseInfo(); @@ -88,8 +79,7 @@ public class MultiTurnsTest extends BaseTest { expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("停留时长")); - expectedParseInfo - .getDimensionFilters() + expectedParseInfo.getDimensionFilters() .add(DataUtils.getFilter("user_name", FilterOperatorEnum.EQUALS, "lucy", "用户", 2L)); expectedParseInfo.setDateInfo( diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java index 53deb9fde..23f1a38d1 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/evaluation/Text2SQLEval.java @@ -35,9 +35,8 @@ public class Text2SQLEval extends BaseTest { for (Long duration : durations) { total_duration += duration; } - System.out.println( - String.format( - "Avg Duration: %d seconds", total_duration / 1000 / durations.size())); + System.out.println(String.format("Avg Duration: %d seconds", + total_duration / 1000 / durations.size())); } @Test @@ -119,11 +118,8 @@ public class Text2SQLEval extends BaseTest { QueryResult result = submitNewChat("过去半个月核心用户的总停留时长", agentId); durations.add(System.currentTimeMillis() - start); assert result.getQueryColumns().size() >= 1; - assert result.getQueryColumns().stream() - .filter(c -> c.getName().contains("停留时长")) - .collect(Collectors.toList()) - .size() - == 1; + assert result.getQueryColumns().stream().filter(c -> c.getName().contains("停留时长")) + .collect(Collectors.toList()).size() == 1; assert result.getQueryResults().size() >= 1; } diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/BaseTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/BaseTest.java index 5f765a562..a28ea4648 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/BaseTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/BaseTest.java @@ -27,9 +27,11 @@ import static java.time.LocalDate.now; public class BaseTest extends BaseApplication { - @Autowired protected SemanticLayerService semanticLayerService; + @Autowired + protected SemanticLayerService semanticLayerService; - @Autowired private DomainRepository domainRepository; + @Autowired + private DomainRepository domainRepository; protected SemanticQueryResp queryBySql(String sql) throws Exception { return queryBySql(sql, User.getFakeUser()); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/DictTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/DictTest.java index a40329100..b807c5094 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/DictTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/DictTest.java @@ -19,9 +19,11 @@ import java.util.Date; import java.util.List; public class DictTest extends BaseTest { - @Autowired private DictConfMapper confMapper; + @Autowired + private DictConfMapper confMapper; - @Autowired private DictTaskService taskService; + @Autowired + private DictTaskService taskService; @Test public void insertConf() { @@ -80,11 +82,8 @@ public class DictTest extends BaseTest { void testAddTask() { editConf(); DictConfDO confDODb = confMapper.selectById(1L); - DictSingleTaskReq dictTask = - DictSingleTaskReq.builder() - .itemId(confDODb.getItemId()) - .type(TypeEnums.DIMENSION) - .build(); + DictSingleTaskReq dictTask = DictSingleTaskReq.builder().itemId(confDODb.getItemId()) + .type(TypeEnums.DIMENSION).build(); taskService.addDictTask(dictTask, null); DictSingleTaskReq taskReq = DictSingleTaskReq.builder().itemId(3L).type(TypeEnums.DIMENSION).build(); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/MetaDiscoveryTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/MetaDiscoveryTest.java index b5a6814ab..9e0f4755f 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/MetaDiscoveryTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/MetaDiscoveryTest.java @@ -14,7 +14,8 @@ import java.util.Collections; public class MetaDiscoveryTest extends BaseTest { - @Autowired protected ChatLayerService chatLayerService; + @Autowired + protected ChatLayerService chatLayerService; @Test public void testGetMapMeta() throws Exception { diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/ModelSchemaTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/ModelSchemaTest.java index 9a4e6c989..83915bb5a 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/ModelSchemaTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/ModelSchemaTest.java @@ -15,7 +15,8 @@ import java.util.stream.Collectors; public class ModelSchemaTest extends BaseTest { - @Autowired private ModelService modelService; + @Autowired + private ModelService modelService; @Test void testGetUnAvailableItem() { @@ -25,10 +26,8 @@ public class ModelSchemaTest extends BaseTest { UnAvailableItemResp unAvailableItemResp = modelService.getUnAvailableItem(fieldRemovedReq); List expectedUnAvailableMetricId = Lists.newArrayList(1L, 4L); List actualUnAvailableMetricId = - unAvailableItemResp.getMetricResps().stream() - .map(MetricResp::getId) - .sorted(Comparator.naturalOrder()) - .collect(Collectors.toList()); + unAvailableItemResp.getMetricResps().stream().map(MetricResp::getId) + .sorted(Comparator.naturalOrder()).collect(Collectors.toList()); Assertions.assertEquals(expectedUnAvailableMetricId, actualUnAvailableMetricId); } } diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryByMetricTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryByMetricTest.java index a5ef3e3a0..f73a27511 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryByMetricTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryByMetricTest.java @@ -15,7 +15,8 @@ import static org.junit.Assert.assertThrows; public class QueryByMetricTest extends BaseTest { - @Autowired protected MetricService metricService; + @Autowired + protected MetricService metricService; @Test public void testWithMetricAndDimensionBizNames() throws Exception { @@ -50,8 +51,7 @@ public class QueryByMetricTest extends BaseTest { queryMetricReq.setDomainId(2L); queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv")); queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department")); - assertThrows( - IllegalArgumentException.class, + assertThrows(IllegalArgumentException.class, () -> queryByMetric(queryMetricReq, User.getFakeUser())); } diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryBySqlTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryBySqlTest.java index 19e1ef69e..6a7ce1727 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryBySqlTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryBySqlTest.java @@ -53,9 +53,8 @@ public class QueryBySqlTest extends BaseTest { @Test public void testFilterQuery() throws Exception { - SemanticQueryResp result = - queryBySql( - "SELECT 部门, SUM(访问次数) AS 总访问次数 FROM 超音数PVUV统计 WHERE 部门 ='HR' GROUP BY 部门 "); + SemanticQueryResp result = queryBySql( + "SELECT 部门, SUM(访问次数) AS 总访问次数 FROM 超音数PVUV统计 WHERE 部门 ='HR' GROUP BY 部门 "); assertEquals(2, result.getColumns().size()); QueryColumn firstColumn = result.getColumns().get(0); QueryColumn secondColumn = result.getColumns().get(1); @@ -101,19 +100,16 @@ public class QueryBySqlTest extends BaseTest { public void testAuthorization_model() { User alice = DataUtils.getUserAlice(); setDomainNotOpenToAll(); - assertThrows( - InvalidPermissionException.class, + assertThrows(InvalidPermissionException.class, () -> queryBySql("SELECT SUM(pv) FROM 超音数PVUV统计 WHERE department ='HR'", alice)); } @Test public void testAuthorization_sensitive_metric() throws Exception { User tom = DataUtils.getUserTom(); - assertThrows( - InvalidPermissionException.class, - () -> - queryBySql( - "SELECT SUM(stay_hours) FROM 停留时长统计 WHERE department ='HR'", tom)); + assertThrows(InvalidPermissionException.class, + () -> queryBySql("SELECT SUM(stay_hours) FROM 停留时长统计 WHERE department ='HR'", + tom)); } @Test @@ -130,8 +126,7 @@ public class QueryBySqlTest extends BaseTest { SemanticQueryResp semanticQueryResp = queryBySql("SELECT SUM(pv) FROM 超音数PVUV统计 WHERE department ='HR'", tom); Assertions.assertNotNull(semanticQueryResp.getQueryAuthorization().getMessage()); - Assertions.assertTrue( - semanticQueryResp.getSql().contains("user_name = 'tom'") - || semanticQueryResp.getSql().contains("`user_name` = 'tom'")); + Assertions.assertTrue(semanticQueryResp.getSql().contains("user_name = 'tom'") + || semanticQueryResp.getSql().contains("`user_name` = 'tom'")); } } diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryByStructTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryByStructTest.java index 4e6d506cc..d3310d1e3 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryByStructTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryByStructTest.java @@ -112,8 +112,7 @@ public class QueryByStructTest extends BaseTest { User alice = DataUtils.getUserAlice(); setDomainNotOpenToAll(); QueryStructReq queryStructReq1 = buildQueryStructReq(Arrays.asList("department")); - assertThrows( - InvalidPermissionException.class, + assertThrows(InvalidPermissionException.class, () -> semanticLayerService.queryByReq(queryStructReq1, alice)); } @@ -125,8 +124,7 @@ public class QueryByStructTest extends BaseTest { aggregator.setColumn("stay_hours"); QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department"), aggregator); - assertThrows( - InvalidPermissionException.class, + assertThrows(InvalidPermissionException.class, () -> semanticLayerService.queryByReq(queryStructReq, tom)); } diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryRuleTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryRuleTest.java index 226fd6b29..f72e07067 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryRuleTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/QueryRuleTest.java @@ -19,7 +19,8 @@ import java.util.List; public class QueryRuleTest extends BaseTest { - @Autowired private QueryRuleService queryRuleService; + @Autowired + private QueryRuleService queryRuleService; private User user = User.getFakeUser(); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/SchemaAuthTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/SchemaAuthTest.java index 6e40fff2d..de28a37e4 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/SchemaAuthTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/SchemaAuthTest.java @@ -19,11 +19,14 @@ import java.util.stream.Collectors; public class SchemaAuthTest extends BaseTest { - @Autowired private DomainService domainService; + @Autowired + private DomainService domainService; - @Autowired private DataSetService dataSetService; + @Autowired + private DataSetService dataSetService; - @Autowired private ModelService modelService; + @Autowired + private ModelService modelService; @Test public void test_getDomainList_alice() { @@ -31,8 +34,7 @@ public class SchemaAuthTest extends BaseTest { setDomainNotOpenToAll(); List domainResps = domainService.getDomainListWithAdminAuth(user); List expectedDomainBizNames = Lists.newArrayList("supersonic", "singer"); - Assertions.assertEquals( - expectedDomainBizNames, + Assertions.assertEquals(expectedDomainBizNames, domainResps.stream().map(DomainResp::getBizName).collect(Collectors.toList())); } @@ -41,8 +43,7 @@ public class SchemaAuthTest extends BaseTest { User user = DataUtils.getUserAlice(); List modelResps = modelService.getModelListWithAuth(user, null, AuthType.ADMIN); List expectedModelBizNames = Lists.newArrayList("user_department", "singer"); - Assertions.assertEquals( - expectedModelBizNames, + Assertions.assertEquals(expectedModelBizNames, modelResps.stream().map(ModelResp::getBizName).collect(Collectors.toList())); } @@ -52,8 +53,7 @@ public class SchemaAuthTest extends BaseTest { List modelResps = modelService.getModelListWithAuth(user, null, AuthType.VISIBLE); List expectedModelBizNames = Lists.newArrayList("user_department", "singer"); - Assertions.assertEquals( - expectedModelBizNames, + Assertions.assertEquals(expectedModelBizNames, modelResps.stream().map(ModelResp::getBizName).collect(Collectors.toList())); } @@ -62,8 +62,7 @@ public class SchemaAuthTest extends BaseTest { User user = DataUtils.getUserAlice(); List dataSetResps = dataSetService.getDataSetsInheritAuth(user, 0L); List expectedDataSetBizNames = Lists.newArrayList("singer"); - Assertions.assertEquals( - expectedDataSetBizNames, + Assertions.assertEquals(expectedDataSetBizNames, dataSetResps.stream().map(DataSetResp::getBizName).collect(Collectors.toList())); } @@ -72,8 +71,7 @@ public class SchemaAuthTest extends BaseTest { User user = DataUtils.getUserJack(); List domainResps = domainService.getDomainListWithAdminAuth(user); List expectedDomainBizNames = Lists.newArrayList("supersonic"); - Assertions.assertEquals( - expectedDomainBizNames, + Assertions.assertEquals(expectedDomainBizNames, domainResps.stream().map(DomainResp::getBizName).collect(Collectors.toList())); } @@ -83,8 +81,7 @@ public class SchemaAuthTest extends BaseTest { List modelResps = modelService.getModelListWithAuth(user, null, AuthType.ADMIN); List expectedModelBizNames = Lists.newArrayList("user_department", "s2_pv_uv_statis", "s2_stay_time_statis"); - Assertions.assertEquals( - expectedModelBizNames, + Assertions.assertEquals(expectedModelBizNames, modelResps.stream().map(ModelResp::getBizName).collect(Collectors.toList())); } @@ -93,8 +90,7 @@ public class SchemaAuthTest extends BaseTest { User user = DataUtils.getUserJack(); List dataSetResps = dataSetService.getDataSetsInheritAuth(user, 0L); List expectedDataSetBizNames = Lists.newArrayList("s2", "singer"); - Assertions.assertEquals( - expectedDataSetBizNames, + Assertions.assertEquals(expectedDataSetBizNames, dataSetResps.stream().map(DataSetResp::getBizName).collect(Collectors.toList())); } } diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TagObjectTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TagObjectTest.java index fa3220583..328eaae4a 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TagObjectTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TagObjectTest.java @@ -13,7 +13,8 @@ import java.util.List; public class TagObjectTest extends BaseTest { - @Autowired private TagObjectService tagObjectService; + @Autowired + private TagObjectService tagObjectService; @Test void testCreateTagObject() throws Exception { diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TagTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TagTest.java index 1fec08bd2..ed31e9cc0 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TagTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TagTest.java @@ -13,7 +13,8 @@ import org.springframework.beans.factory.annotation.Autowired; @TestMethodOrder(MethodOrderer.OrderAnnotation.class) public class TagTest extends BaseTest { - @Autowired private TagQueryService tagQueryService; + @Autowired + private TagQueryService tagQueryService; @Test public void testQueryTagValue() throws Exception { diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslateTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslateTest.java index 00ff9d6d1..725be1417 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslateTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/headless/TranslateTest.java @@ -17,10 +17,9 @@ public class TranslateTest extends BaseTest { @Test public void testSqlExplain() throws Exception { String sql = "SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 "; - SemanticTranslateResp explain = - semanticLayerService.translate( - QueryReqBuilder.buildS2SQLReq(sql, DataUtils.getMetricAgentView()), - User.getFakeUser()); + SemanticTranslateResp explain = semanticLayerService.translate( + QueryReqBuilder.buildS2SQLReq(sql, DataUtils.getMetricAgentView()), + User.getFakeUser()); assertNotNull(explain); assertNotNull(explain.getQuerySQL()); assertTrue(explain.getQuerySQL().contains("department")); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/provider/ModelProviderTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/provider/ModelProviderTest.java index 7736de550..95b862ee1 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/provider/ModelProviderTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/provider/ModelProviderTest.java @@ -51,11 +51,9 @@ public class ModelProviderTest extends BaseApplication { modelConfig.setEndpoint(QianfanModelFactory.DEFAULT_ENDPOINT); ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig); - assertThrows( - RuntimeException.class, - () -> { - chatModel.generate("hi"); - }); + assertThrows(RuntimeException.class, () -> { + chatModel.generate("hi"); + }); } @Test @@ -67,11 +65,9 @@ public class ModelProviderTest extends BaseApplication { modelConfig.setApiKey("e2724491714b3b2a0274e987905f1001.5JyHgf4vbZVJ7gC5"); ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig); - assertThrows( - RuntimeException.class, - () -> { - chatModel.generate("hi"); - }); + assertThrows(RuntimeException.class, () -> { + chatModel.generate("hi"); + }); } @Test @@ -84,11 +80,9 @@ public class ModelProviderTest extends BaseApplication { modelConfig.setApiKey(ParameterConfig.DEMO); ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig); - assertThrows( - RuntimeException.class, - () -> { - chatModel.generate("hi"); - }); + assertThrows(RuntimeException.class, () -> { + chatModel.generate("hi"); + }); } @Test @@ -100,11 +94,9 @@ public class ModelProviderTest extends BaseApplication { modelConfig.setApiKey(ParameterConfig.DEMO); ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig); - assertThrows( - RuntimeException.class, - () -> { - chatModel.generate("hi"); - }); + assertThrows(RuntimeException.class, () -> { + chatModel.generate("hi"); + }); } @Test @@ -140,11 +132,9 @@ public class ModelProviderTest extends BaseApplication { modelConfig.setApiKey(ParameterConfig.DEMO); EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig); - assertThrows( - RuntimeException.class, - () -> { - embeddingModel.embed("hi"); - }); + assertThrows(RuntimeException.class, () -> { + embeddingModel.embed("hi"); + }); } @Test @@ -156,11 +146,9 @@ public class ModelProviderTest extends BaseApplication { modelConfig.setApiKey(ParameterConfig.DEMO); EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig); - assertThrows( - RuntimeException.class, - () -> { - embeddingModel.embed("hi"); - }); + assertThrows(RuntimeException.class, () -> { + embeddingModel.embed("hi"); + }); } @Test @@ -173,11 +161,9 @@ public class ModelProviderTest extends BaseApplication { modelConfig.setSecretKey(ParameterConfig.DEMO); EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig); - assertThrows( - RuntimeException.class, - () -> { - embeddingModel.embed("hi"); - }); + assertThrows(RuntimeException.class, () -> { + embeddingModel.embed("hi"); + }); } @Test @@ -189,10 +175,8 @@ public class ModelProviderTest extends BaseApplication { modelConfig.setApiKey("e2724491714b3b2a0274e987905f1001.5JyHgf4vbZVJ7gC5"); EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig); - assertThrows( - RuntimeException.class, - () -> { - embeddingModel.embed("hi"); - }); + assertThrows(RuntimeException.class, () -> { + embeddingModel.embed("hi"); + }); } } diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java b/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java index 68b007e4a..48440fb89 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java @@ -49,12 +49,8 @@ public class DataUtils { return SchemaElement.builder().name(name).build(); } - public static QueryFilter getFilter( - String bizName, - FilterOperatorEnum filterOperatorEnum, - Object value, - String name, - Long elementId) { + public static QueryFilter getFilter(String bizName, FilterOperatorEnum filterOperatorEnum, + Object value, String name, Long elementId) { QueryFilter filter = new QueryFilter(); filter.setBizName(bizName); filter.setOperator(filterOperatorEnum); @@ -64,8 +60,8 @@ public class DataUtils { return filter; } - public static DateConf getDateConf( - Integer unit, DateConf.DateMode dateMode, DatePeriodEnum period) { + public static DateConf getDateConf(Integer unit, DateConf.DateMode dateMode, + DatePeriodEnum period) { DateConf dateInfo = new DateConf(); dateInfo.setUnit(unit); dateInfo.setDateMode(dateMode); @@ -75,12 +71,8 @@ public class DataUtils { return dateInfo; } - public static DateConf getDateConf( - DateConf.DateMode dateMode, - Integer unit, - DatePeriodEnum period, - String startDate, - String endDate) { + public static DateConf getDateConf(DateConf.DateMode dateMode, Integer unit, + DatePeriodEnum period, String startDate, String endDate) { DateConf dateInfo = new DateConf(); dateInfo.setUnit(unit); dateInfo.setDateMode(dateMode); diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/util/LLMConfigUtils.java b/launchers/standalone/src/test/java/com/tencent/supersonic/util/LLMConfigUtils.java index 0f43b415c..68a7a8126 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/util/LLMConfigUtils.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/util/LLMConfigUtils.java @@ -4,14 +4,9 @@ import com.tencent.supersonic.common.pojo.ChatModelConfig; public class LLMConfigUtils { public enum LLMType { - OPENAI_GPT(false), - OPENAI_MOONSHOT(false), - OPENAI_DEEPSEEK(false), - OPENAI_QWEN(false), - OPENAI_GLM(false), - OLLAMA_LLAMA3(true), - OLLAMA_QWEN2(true), - OLLAMA_QWEN25(true); + OPENAI_GPT(false), OPENAI_MOONSHOT(false), OPENAI_DEEPSEEK(false), OPENAI_QWEN( + false), OPENAI_GLM( + false), OLLAMA_LLAMA3(true), OLLAMA_QWEN2(true), OLLAMA_QWEN25(true); private boolean isOllam; @@ -70,24 +65,12 @@ public class LLMConfigUtils { ChatModelConfig chatModelConfig; if (type.isOllam) { - chatModelConfig = - ChatModelConfig.builder() - .provider("ollama") - .baseUrl(baseUrl) - .modelName(modelName) - .temperature(temperature) - .timeOut(60000L) - .build(); + chatModelConfig = ChatModelConfig.builder().provider("ollama").baseUrl(baseUrl) + .modelName(modelName).temperature(temperature).timeOut(60000L).build(); } else { chatModelConfig = - ChatModelConfig.builder() - .provider("open_ai") - .baseUrl(baseUrl) - .apiKey(apiKey) - .modelName(modelName) - .temperature(temperature) - .timeOut(60000L) - .build(); + ChatModelConfig.builder().provider("open_ai").baseUrl(baseUrl).apiKey(apiKey) + .modelName(modelName).temperature(temperature).timeOut(60000L).build(); } return chatModelConfig; diff --git a/pom.xml b/pom.xml index 943b5ee18..8879c414b 100644 --- a/pom.xml +++ b/pom.xml @@ -245,10 +245,9 @@ ${spotless.version} - - 1.7 - - + + java-formatter.xml + javax,java,scala,\#