(improvement)[build] Use Spotless to customize the code formatting (#1750)

This commit is contained in:
lexluo09
2024-10-04 00:05:04 +08:00
committed by GitHub
parent 44d1cde34f
commit 71a9954be5
521 changed files with 7811 additions and 13046 deletions

View File

@@ -7,4 +7,5 @@ import java.lang.annotation.Target;
@Target({ElementType.METHOD}) @Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME) @Retention(RetentionPolicy.RUNTIME)
public @interface AuthenticationIgnore {} public @interface AuthenticationIgnore {
}

View File

@@ -24,9 +24,8 @@ public class AuthenticationConfig {
@Value("${s2.authentication.token.default.appKey:supersonic}") @Value("${s2.authentication.token.default.appKey:supersonic}")
private String tokenDefaultAppKey; private String tokenDefaultAppKey;
@Value( @Value("${s2.authentication.token.appSecret:supersonic:WIaO9YRRVt+7QtpPvyWsARFngnEcbaKBk"
"${s2.authentication.token.appSecret:supersonic:WIaO9YRRVt+7QtpPvyWsARFngnEcbaKBk" + "783uGFwMrbJBaochsqCH62L4Kijcb0sZCYoSsiKGV/zPml5MnZ3uQ==}")
+ "783uGFwMrbJBaochsqCH62L4Kijcb0sZCYoSsiKGV/zPml5MnZ3uQ==}")
private String tokenAppSecret; private String tokenAppSecret;
@Value("${s2.authentication.token.http.header.key:Authorization}") @Value("${s2.authentication.token.http.header.key:Authorization}")
@@ -48,8 +47,7 @@ public class AuthenticationConfig {
private Long tokenTimeout; private Long tokenTimeout;
public Map<String, String> getAppKeyToSecretMap() { public Map<String, String> getAppKeyToSecretMap() {
return Arrays.stream(this.tokenAppSecret.split(",")) return Arrays.stream(this.tokenAppSecret.split(",")).map(s -> s.split(":"))
.map(s -> s.split(":"))
.collect(Collectors.toMap(e -> e[0].trim(), e -> e[1].trim())); .collect(Collectors.toMap(e -> e[0].trim(), e -> e[1].trim()));
} }
} }

View File

@@ -20,8 +20,8 @@ public class User {
private Integer isAdmin; private Integer isAdmin;
public static User get( public static User get(Long id, String name, String displayName, String email,
Long id, String name, String displayName, String email, Integer isAdmin) { Integer isAdmin) {
return new User(id, name, displayName, email, isAdmin); return new User(id, name, displayName, email, isAdmin);
} }

View File

@@ -9,24 +9,14 @@ public class UserWithPassword extends User {
private String password; private String password;
public UserWithPassword( public UserWithPassword(Long id, String name, String displayName, String email, String password,
Long id,
String name,
String displayName,
String email,
String password,
Integer isAdmin) { Integer isAdmin) {
super(id, name, displayName, email, isAdmin); super(id, name, displayName, email, isAdmin);
this.password = password; this.password = password;
} }
public static UserWithPassword get( public static UserWithPassword get(Long id, String name, String displayName, String email,
Long id, String password, Integer isAdmin) {
String name,
String displayName,
String email,
String password,
Integer isAdmin) {
return new UserWithPassword(id, name, displayName, email, password, isAdmin); return new UserWithPassword(id, name, displayName, email, password, isAdmin);
} }
} }

View File

@@ -12,8 +12,8 @@ import java.util.Set;
public interface UserService { public interface UserService {
User getCurrentUser( User getCurrentUser(HttpServletRequest httpServletRequest,
HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse); HttpServletResponse httpServletResponse);
List<String> getUserNames(); List<String> getUserNames();

View File

@@ -52,12 +52,10 @@ public class DefaultUserAdaptor implements UserAdaptor {
new Organization("1", "0", "SuperSonic", "SuperSonic", Lists.newArrayList(), true); new Organization("1", "0", "SuperSonic", "SuperSonic", Lists.newArrayList(), true);
Organization hr = Organization hr =
new Organization("2", "1", "Hr", "SuperSonic/Hr", Lists.newArrayList(), false); new Organization("2", "1", "Hr", "SuperSonic/Hr", Lists.newArrayList(), false);
Organization sales = Organization sales = new Organization("3", "1", "Sales", "SuperSonic/Sales",
new Organization( Lists.newArrayList(), false);
"3", "1", "Sales", "SuperSonic/Sales", Lists.newArrayList(), false); Organization marketing = new Organization("4", "1", "Marketing", "SuperSonic/Marketing",
Organization marketing = Lists.newArrayList(), false);
new Organization(
"4", "1", "Marketing", "SuperSonic/Marketing", Lists.newArrayList(), false);
List<Organization> subOrganization = Lists.newArrayList(hr, sales, marketing); List<Organization> subOrganization = Lists.newArrayList(hr, sales, marketing);
superSonic.setSubOrganizations(subOrganization); superSonic.setSubOrganizations(subOrganization);
return Lists.newArrayList(superSonic); return Lists.newArrayList(superSonic);
@@ -113,19 +111,12 @@ public class DefaultUserAdaptor implements UserAdaptor {
throw new RuntimeException("user not exist,please register"); throw new RuntimeException("user not exist,please register");
} }
try { try {
String password = String password = AESEncryptionUtil.encrypt(userReq.getPassword(),
AESEncryptionUtil.encrypt( AESEncryptionUtil.getBytesFromString(userDO.getSalt()));
userReq.getPassword(),
AESEncryptionUtil.getBytesFromString(userDO.getSalt()));
if (userDO.getPassword().equals(password)) { if (userDO.getPassword().equals(password)) {
UserWithPassword user = UserWithPassword user = UserWithPassword.get(userDO.getId(), userDO.getName(),
UserWithPassword.get( userDO.getDisplayName(), userDO.getEmail(), userDO.getPassword(),
userDO.getId(), userDO.getIsAdmin());
userDO.getName(),
userDO.getDisplayName(),
userDO.getEmail(),
userDO.getPassword(),
userDO.getIsAdmin());
return user; return user;
} else { } else {
throw new RuntimeException("password not correct, please try again"); throw new RuntimeException("password not correct, please try again");

View File

@@ -68,8 +68,8 @@ public abstract class AuthenticationInterceptor implements HandlerInterceptor {
try { try {
if (request instanceof StandardMultipartHttpServletRequest) { if (request instanceof StandardMultipartHttpServletRequest) {
RequestFacade servletRequest = RequestFacade servletRequest =
(RequestFacade) (RequestFacade) ((StandardMultipartHttpServletRequest) request)
((StandardMultipartHttpServletRequest) request).getRequest(); .getRequest();
Class<? extends HttpServletRequest> servletRequestClazz = servletRequest.getClass(); Class<? extends HttpServletRequest> servletRequestClazz = servletRequest.getClass();
Field request1 = servletRequestClazz.getDeclaredField("request"); Field request1 = servletRequestClazz.getDeclaredField("request");
request1.setAccessible(true); request1.setAccessible(true);

View File

@@ -22,9 +22,8 @@ import java.lang.reflect.Method;
public class DefaultAuthenticationInterceptor extends AuthenticationInterceptor { public class DefaultAuthenticationInterceptor extends AuthenticationInterceptor {
@Override @Override
public boolean preHandle( public boolean preHandle(HttpServletRequest request, HttpServletResponse response,
HttpServletRequest request, HttpServletResponse response, Object handler) Object handler) throws AccessException {
throws AccessException {
authenticationConfig = ContextUtils.getBean(AuthenticationConfig.class); authenticationConfig = ContextUtils.getBean(AuthenticationConfig.class);
userServiceImpl = ContextUtils.getBean(UserServiceImpl.class); userServiceImpl = ContextUtils.getBean(UserServiceImpl.class);
userTokenUtils = ContextUtils.getBean(UserTokenUtils.class); userTokenUtils = ContextUtils.getBean(UserTokenUtils.class);
@@ -74,11 +73,9 @@ public class DefaultAuthenticationInterceptor extends AuthenticationInterceptor
} }
private void setContext(String userName, HttpServletRequest request) { private void setContext(String userName, HttpServletRequest request) {
ThreadContext threadContext = ThreadContext threadContext = ThreadContext.builder()
ThreadContext.builder() .token(request.getHeader(authenticationConfig.getTokenHttpHeaderKey()))
.token(request.getHeader(authenticationConfig.getTokenHttpHeaderKey())) .userName(userName).build();
.userName(userName)
.build();
s2ThreadContext.set(threadContext); s2ThreadContext.set(threadContext);
} }
} }

View File

@@ -13,17 +13,14 @@ public class InterceptorFactory implements WebMvcConfigurer {
private List<AuthenticationInterceptor> authenticationInterceptors; private List<AuthenticationInterceptor> authenticationInterceptors;
public InterceptorFactory() { public InterceptorFactory() {
authenticationInterceptors = authenticationInterceptors = SpringFactoriesLoader.loadFactories(
SpringFactoriesLoader.loadFactories( AuthenticationInterceptor.class, Thread.currentThread().getContextClassLoader());
AuthenticationInterceptor.class,
Thread.currentThread().getContextClassLoader());
} }
@Override @Override
public void addInterceptors(InterceptorRegistry registry) { public void addInterceptors(InterceptorRegistry registry) {
for (AuthenticationInterceptor authenticationInterceptor : authenticationInterceptors) { for (AuthenticationInterceptor authenticationInterceptor : authenticationInterceptors) {
registry.addInterceptor(authenticationInterceptor) registry.addInterceptor(authenticationInterceptor).addPathPatterns("/**")
.addPathPatterns("/**")
.excludePathPatterns("/", "/webapp/**", "/error"); .excludePathPatterns("/", "/webapp/**", "/error");
} }
} }

View File

@@ -138,8 +138,8 @@ public class UserDOExample {
criteria.add(new Criterion(condition, value)); criteria.add(new Criterion(condition, value));
} }
protected void addCriterion( protected void addCriterion(String condition, Object value1, Object value2,
String condition, Object value1, Object value2, String property) { String property) {
if (value1 == null || value2 == null) { if (value1 == null || value2 == null) {
throw new RuntimeException("Between values for " + property + " cannot be null"); throw new RuntimeException("Between values for " + property + " cannot be null");
} }
@@ -628,8 +628,8 @@ public class UserDOExample {
this(condition, value, null); this(condition, value, null);
} }
protected Criterion( protected Criterion(String condition, Object value, Object secondValue,
String condition, Object value, Object secondValue, String typeHandler) { String typeHandler) {
super(); super();
this.condition = condition; this.condition = condition;
this.value = value; this.value = value;

View File

@@ -30,8 +30,8 @@ public class UserController {
} }
@GetMapping("/getCurrentUser") @GetMapping("/getCurrentUser")
public User getCurrentUser( public User getCurrentUser(HttpServletRequest httpServletRequest,
HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) { HttpServletResponse httpServletResponse) {
return userService.getCurrentUser(httpServletRequest, httpServletResponse); return userService.getCurrentUser(httpServletRequest, httpServletResponse);
} }

View File

@@ -27,8 +27,8 @@ public class UserServiceImpl implements UserService {
} }
@Override @Override
public User getCurrentUser( public User getCurrentUser(HttpServletRequest httpServletRequest,
HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) { HttpServletResponse httpServletResponse) {
User user = UserHolder.findUser(httpServletRequest, httpServletResponse); User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
if (user != null) { if (user != null) {
SystemConfig systemConfig = sysParameterService.getSystemConfig(); SystemConfig systemConfig = sysParameterService.getSystemConfig();

View File

@@ -18,8 +18,8 @@ public class UserStrategyFactory {
private AuthenticationConfig authenticationConfig; private AuthenticationConfig authenticationConfig;
public UserStrategyFactory( public UserStrategyFactory(AuthenticationConfig authenticationConfig,
AuthenticationConfig authenticationConfig, List<UserStrategy> userStrategyList) { List<UserStrategy> userStrategyList) {
this.authenticationConfig = authenticationConfig; this.authenticationConfig = authenticationConfig;
this.userStrategyList = userStrategyList; this.userStrategyList = userStrategyList;
} }

View File

@@ -17,8 +17,7 @@ public class ComponentFactory {
} }
private static <T> T init(Class<T> factoryType) { private static <T> T init(Class<T> factoryType) {
return SpringFactoriesLoader.loadFactories( return SpringFactoriesLoader
factoryType, Thread.currentThread().getContextClassLoader()) .loadFactories(factoryType, Thread.currentThread().getContextClassLoader()).get(0);
.get(0);
} }
} }

View File

@@ -48,8 +48,7 @@ public class UserTokenUtils {
Map<String, Object> claims = new HashMap<>(5); Map<String, Object> claims = new HashMap<>(5);
claims.put(TOKEN_USER_ID, user.getId()); claims.put(TOKEN_USER_ID, user.getId());
claims.put(TOKEN_USER_NAME, StringUtils.isEmpty(user.getName()) ? "" : user.getName()); claims.put(TOKEN_USER_NAME, StringUtils.isEmpty(user.getName()) ? "" : user.getName());
claims.put( claims.put(TOKEN_USER_PASSWORD,
TOKEN_USER_PASSWORD,
StringUtils.isEmpty(user.getPassword()) ? "" : user.getPassword()); StringUtils.isEmpty(user.getPassword()) ? "" : user.getPassword());
claims.put(TOKEN_USER_DISPLAY_NAME, user.getDisplayName()); claims.put(TOKEN_USER_DISPLAY_NAME, user.getDisplayName());
claims.put(TOKEN_CREATE_TIME, System.currentTimeMillis()); claims.put(TOKEN_CREATE_TIME, System.currentTimeMillis());
@@ -83,10 +82,8 @@ public class UserTokenUtils {
String userName = String.valueOf(claims.get(TOKEN_USER_NAME)); String userName = String.valueOf(claims.get(TOKEN_USER_NAME));
String email = String.valueOf(claims.get(TOKEN_USER_EMAIL)); String email = String.valueOf(claims.get(TOKEN_USER_EMAIL));
String displayName = String.valueOf(claims.get(TOKEN_USER_DISPLAY_NAME)); String displayName = String.valueOf(claims.get(TOKEN_USER_DISPLAY_NAME));
Integer isAdmin = Integer isAdmin = claims.get(TOKEN_IS_ADMIN) == null ? 0
claims.get(TOKEN_IS_ADMIN) == null : Integer.parseInt(claims.get(TOKEN_IS_ADMIN).toString());
? 0
: Integer.parseInt(claims.get(TOKEN_IS_ADMIN).toString());
return User.get(userId, userName, displayName, email, isAdmin); return User.get(userId, userName, displayName, email, isAdmin);
} }
@@ -105,10 +102,8 @@ public class UserTokenUtils {
String email = String.valueOf(claims.get(TOKEN_USER_EMAIL)); String email = String.valueOf(claims.get(TOKEN_USER_EMAIL));
String displayName = String.valueOf(claims.get(TOKEN_USER_DISPLAY_NAME)); String displayName = String.valueOf(claims.get(TOKEN_USER_DISPLAY_NAME));
String password = String.valueOf(claims.get(TOKEN_USER_PASSWORD)); String password = String.valueOf(claims.get(TOKEN_USER_PASSWORD));
Integer isAdmin = Integer isAdmin = claims.get(TOKEN_IS_ADMIN) == null ? 0
claims.get(TOKEN_IS_ADMIN) == null : Integer.parseInt(claims.get(TOKEN_IS_ADMIN).toString());
? 0
: Integer.parseInt(claims.get(TOKEN_IS_ADMIN).toString());
return UserWithPassword.get(userId, userName, displayName, email, password, isAdmin); return UserWithPassword.get(userId, userName, displayName, email, password, isAdmin);
} }
@@ -121,11 +116,8 @@ public class UserTokenUtils {
try { try {
String tokenSecret = getTokenSecret(appKey); String tokenSecret = getTokenSecret(appKey);
Claims claims = Claims claims =
Jwts.parser() Jwts.parser().setSigningKey(tokenSecret.getBytes(StandardCharsets.UTF_8))
.setSigningKey(tokenSecret.getBytes(StandardCharsets.UTF_8)) .build().parseClaimsJws(getTokenString(token)).getBody();
.build()
.parseClaimsJws(getTokenString(token))
.getBody();
return Optional.of(claims); return Optional.of(claims);
} catch (Exception e) { } catch (Exception e) {
log.info("can not getClaims from appKey:{} token:{}, please login", appKey, token); log.info("can not getClaims from appKey:{} token:{}, please login", appKey, token);
@@ -149,15 +141,10 @@ public class UserTokenUtils {
Date expirationDate = new Date(expiration); Date expirationDate = new Date(expiration);
String tokenSecret = getTokenSecret(appKey); String tokenSecret = getTokenSecret(appKey);
return Jwts.builder() return Jwts.builder().setClaims(claims).setSubject(claims.get(TOKEN_USER_NAME).toString())
.setClaims(claims)
.setSubject(claims.get(TOKEN_USER_NAME).toString())
.setExpiration(expirationDate) .setExpiration(expirationDate)
.signWith( .signWith(new SecretKeySpec(tokenSecret.getBytes(StandardCharsets.UTF_8),
new SecretKeySpec( SignatureAlgorithm.HS512.getJcaName()), SignatureAlgorithm.HS512)
tokenSecret.getBytes(StandardCharsets.UTF_8),
SignatureAlgorithm.HS512.getJcaName()),
SignatureAlgorithm.HS512)
.compact(); .compact();
} }

View File

@@ -31,8 +31,7 @@ public class AuthController {
} }
@GetMapping("/queryGroup") @GetMapping("/queryGroup")
public List<AuthGroup> queryAuthGroup( public List<AuthGroup> queryAuthGroup(@RequestParam("modelId") String modelId,
@RequestParam("modelId") String modelId,
@RequestParam(value = "groupId", required = false) Integer groupId) { @RequestParam(value = "groupId", required = false) Integer groupId) {
return authService.queryAuthGroups(modelId, groupId); return authService.queryAuthGroups(modelId, groupId);
} }
@@ -69,10 +68,8 @@ public class AuthController {
* @return * @return
*/ */
@PostMapping("/queryAuthorizedRes") @PostMapping("/queryAuthorizedRes")
public AuthorizedResourceResp queryAuthorizedResources( public AuthorizedResourceResp queryAuthorizedResources(@RequestBody QueryAuthResReq req,
@RequestBody QueryAuthResReq req, HttpServletRequest request, HttpServletResponse response) {
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response); User user = UserHolder.findUser(request, response);
return authService.queryAuthorizedResources(req, user); return authService.queryAuthorizedResources(req, user);
} }

View File

@@ -39,18 +39,15 @@ public class AuthServiceImpl implements AuthService {
List<String> rows = List<String> rows =
jdbcTemplate.queryForList("select config from s2_auth_groups", String.class); jdbcTemplate.queryForList("select config from s2_auth_groups", String.class);
Gson g = new Gson(); Gson g = new Gson();
return rows.stream() return rows.stream().map(row -> g.fromJson(row, AuthGroup.class))
.map(row -> g.fromJson(row, AuthGroup.class))
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
@Override @Override
public List<AuthGroup> queryAuthGroups(String modelId, Integer groupId) { public List<AuthGroup> queryAuthGroups(String modelId, Integer groupId) {
return load().stream() return load().stream()
.filter( .filter(group -> (Objects.isNull(groupId) || groupId.equals(group.getGroupId()))
group -> && modelId.equals(group.getModelId().toString()))
(Objects.isNull(groupId) || groupId.equals(group.getGroupId()))
&& modelId.equals(group.getModelId().toString()))
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
@@ -65,15 +62,11 @@ public class AuthServiceImpl implements AuthService {
nextGroupId = obj + 1; nextGroupId = obj + 1;
} }
group.setGroupId(nextGroupId); group.setGroupId(nextGroupId);
jdbcTemplate.update( jdbcTemplate.update("insert into s2_auth_groups (group_id, config) values (?, ?);",
"insert into s2_auth_groups (group_id, config) values (?, ?);", nextGroupId, g.toJson(group));
nextGroupId,
g.toJson(group));
} else { } else {
jdbcTemplate.update( jdbcTemplate.update("update s2_auth_groups set config = ? where group_id = ?;",
"update s2_auth_groups set config = ? where group_id = ?;", g.toJson(group), group.getGroupId());
g.toJson(group),
group.getGroupId());
} }
} }
@@ -119,30 +112,24 @@ public class AuthServiceImpl implements AuthService {
return resource; return resource;
} }
private List<AuthGroup> getAuthGroups( private List<AuthGroup> getAuthGroups(List<Long> modelIds, String userName,
List<Long> modelIds, String userName, List<String> departmentIds) { List<String> departmentIds) {
List<AuthGroup> groups = List<AuthGroup> groups = load().stream().filter(group -> {
load().stream() if (!modelIds.contains(group.getModelId())) {
.filter( return false;
group -> { }
if (!modelIds.contains(group.getModelId())) { if (!CollectionUtils.isEmpty(group.getAuthorizedUsers())
return false; && group.getAuthorizedUsers().contains(userName)) {
} return true;
if (!CollectionUtils.isEmpty(group.getAuthorizedUsers()) }
&& group.getAuthorizedUsers().contains(userName)) { for (String departmentId : departmentIds) {
return true; if (!CollectionUtils.isEmpty(group.getAuthorizedDepartmentIds())
} && group.getAuthorizedDepartmentIds().contains(departmentId)) {
for (String departmentId : departmentIds) { return true;
if (!CollectionUtils.isEmpty( }
group.getAuthorizedDepartmentIds()) }
&& group.getAuthorizedDepartmentIds() return false;
.contains(departmentId)) { }).collect(Collectors.toList());
return true;
}
}
return false;
})
.collect(Collectors.toList());
log.info("user:{} department:{} authGroups:{}", userName, departmentIds, groups); log.info("user:{} department:{} authGroups:{}", userName, departmentIds, groups);
return groups; return groups;
} }

View File

@@ -4,8 +4,7 @@ import com.tencent.supersonic.common.pojo.exception.InvalidArgumentException;
import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.StringUtils;
public enum MemoryReviewResult { public enum MemoryReviewResult {
POSITIVE, POSITIVE, NEGATIVE;
NEGATIVE;
public static MemoryReviewResult getMemoryReviewResult(String value) { public static MemoryReviewResult getMemoryReviewResult(String value) {
String validValue = StringUtils.trim(value); String validValue = StringUtils.trim(value);

View File

@@ -1,7 +1,5 @@
package com.tencent.supersonic.chat.api.pojo.enums; package com.tencent.supersonic.chat.api.pojo.enums;
public enum MemoryStatus { public enum MemoryStatus {
PENDING, PENDING, ENABLED, DISABLED;
ENABLED,
DISABLED;
} }

View File

@@ -14,7 +14,8 @@ public class KnowledgeInfoReq {
private String bizName; private String bizName;
/** type: IntentionTypeEnum temporarily only supports dimension-related information */ /** type: IntentionTypeEnum temporarily only supports dimension-related information */
@NotNull private TypeEnums type = TypeEnums.DIMENSION; @NotNull
private TypeEnums type = TypeEnums.DIMENSION;
private Boolean searchEnable = false; private Boolean searchEnable = false;

View File

@@ -43,16 +43,12 @@ public class Agent extends RecordInfo {
return Lists.newArrayList(); return Lists.newArrayList();
} }
List<Map> toolList = (List) map.get("tools"); List<Map> toolList = (List) map.get("tools");
return toolList.stream() return toolList.stream().filter(tool -> {
.filter( if (Objects.isNull(type)) {
tool -> { return true;
if (Objects.isNull(type)) { }
return true; return type.name().equals(tool.get("type"));
} }).map(JSONObject::toJSONString).collect(Collectors.toList());
return type.name().equals(tool.get("type"));
})
.map(JSONObject::toJSONString)
.collect(Collectors.toList());
} }
public boolean enableSearch() { public boolean enableSearch() {
@@ -72,8 +68,7 @@ public class Agent extends RecordInfo {
if (CollectionUtils.isEmpty(tools)) { if (CollectionUtils.isEmpty(tools)) {
return Lists.newArrayList(); return Lists.newArrayList();
} }
return tools.stream() return tools.stream().map(tool -> JSONObject.parseObject(tool, NL2SQLTool.class))
.map(tool -> JSONObject.parseObject(tool, NL2SQLTool.class))
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
@@ -120,10 +115,8 @@ public class Agent extends RecordInfo {
if (CollectionUtils.isEmpty(commonAgentTools)) { if (CollectionUtils.isEmpty(commonAgentTools)) {
return new HashSet<>(); return new HashSet<>();
} }
return commonAgentTools.stream() return commonAgentTools.stream().map(NL2SQLTool::getDataSetIds)
.map(NL2SQLTool::getDataSetIds) .filter(modelIds -> !CollectionUtils.isEmpty(modelIds)).flatMap(Collection::stream)
.filter(modelIds -> !CollectionUtils.isEmpty(modelIds))
.flatMap(Collection::stream)
.collect(Collectors.toSet()); .collect(Collectors.toSet());
} }
} }

View File

@@ -4,9 +4,7 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
public enum AgentToolType { public enum AgentToolType {
NL2SQL_RULE("基于规则Text-to-SQL"), NL2SQL_RULE("基于规则Text-to-SQL"), NL2SQL_LLM("基于大模型Text-to-SQL"), PLUGIN("第三方插件");
NL2SQL_LLM("基于大模型Text-to-SQL"),
PLUGIN("第三方插件");
private String title; private String title;

View File

@@ -26,14 +26,10 @@ import static com.tencent.supersonic.chat.server.parser.ParserConfig.PARSER_MULT
public class PlainTextExecutor implements ChatQueryExecutor { public class PlainTextExecutor implements ChatQueryExecutor {
private static final String INSTRUCTION = private static final String INSTRUCTION = "" + "#Role: You are a nice person to talk to.\n"
"" + "#Task: Respond quickly and nicely to the user."
+ "#Role: You are a nice person to talk to.\n" + "#Rules: 1.ALWAYS use the same language as the input.\n" + "#History Inputs: %s\n"
+ "#Task: Respond quickly and nicely to the user." + "#Current Input: %s\n" + "#Your response: ";
+ "#Rules: 1.ALWAYS use the same language as the input.\n"
+ "#History Inputs: %s\n"
+ "#Current Input: %s\n"
+ "#Your response: ";
@Override @Override
public QueryResult execute(ExecuteContext executeContext) { public QueryResult execute(ExecuteContext executeContext) {
@@ -41,11 +37,8 @@ public class PlainTextExecutor implements ChatQueryExecutor {
return null; return null;
} }
String promptStr = String promptStr = String.format(INSTRUCTION, getHistoryInputs(executeContext),
String.format( executeContext.getQueryText());
INSTRUCTION,
getHistoryInputs(executeContext),
executeContext.getQueryText());
Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP); Prompt prompt = PromptTemplate.from(promptStr).apply(Collections.EMPTY_MAP);
AgentService agentService = ContextUtils.getBean(AgentService.class); AgentService agentService = ContextUtils.getBean(AgentService.class);
@@ -74,18 +67,15 @@ public class PlainTextExecutor implements ChatQueryExecutor {
Boolean globalMultiTurnConfig = Boolean globalMultiTurnConfig =
Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE)); Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
Boolean multiTurnConfig = Boolean multiTurnConfig =
agentMultiTurnConfig != null agentMultiTurnConfig != null ? agentMultiTurnConfig.isEnableMultiTurn()
? agentMultiTurnConfig.isEnableMultiTurn()
: globalMultiTurnConfig; : globalMultiTurnConfig;
if (Boolean.TRUE.equals(multiTurnConfig)) { if (Boolean.TRUE.equals(multiTurnConfig)) {
List<QueryResp> queryResps = getHistoryQueries(executeContext.getChatId(), 5); List<QueryResp> queryResps = getHistoryQueries(executeContext.getChatId(), 5);
queryResps.stream() queryResps.stream().forEach(p -> {
.forEach( historyInput.append(p.getQueryText());
p -> { historyInput.append(";");
historyInput.append(p.getQueryText()); });
historyInput.append(";");
});
} }
return historyInput.toString(); return historyInput.toString();
@@ -93,18 +83,13 @@ public class PlainTextExecutor implements ChatQueryExecutor {
private List<QueryResp> getHistoryQueries(int chatId, int multiNum) { private List<QueryResp> getHistoryQueries(int chatId, int multiNum) {
ChatManageService chatManageService = ContextUtils.getBean(ChatManageService.class); ChatManageService chatManageService = ContextUtils.getBean(ChatManageService.class);
List<QueryResp> contextualParseInfoList = List<QueryResp> contextualParseInfoList = chatManageService.getChatQueries(chatId).stream()
chatManageService.getChatQueries(chatId).stream() .filter(q -> Objects.nonNull(q.getQueryResult())
.filter( && q.getQueryResult().getQueryState() == QueryState.SUCCESS)
q -> .collect(Collectors.toList());
Objects.nonNull(q.getQueryResult())
&& q.getQueryResult().getQueryState()
== QueryState.SUCCESS)
.collect(Collectors.toList());
List<QueryResp> contextualList = List<QueryResp> contextualList = contextualParseInfoList.subList(0,
contextualParseInfoList.subList( Math.min(multiNum, contextualParseInfoList.size()));
0, Math.min(multiNum, contextualParseInfoList.size()));
Collections.reverse(contextualList); Collections.reverse(contextualList);
return contextualList; return contextualList;

View File

@@ -31,35 +31,26 @@ public class SqlExecutor implements ChatQueryExecutor {
QueryResult queryResult = doExecute(executeContext); QueryResult queryResult = doExecute(executeContext);
if (queryResult != null) { if (queryResult != null) {
String textResult = String textResult = ResultFormatter.transform2TextNew(queryResult.getQueryColumns(),
ResultFormatter.transform2TextNew( queryResult.getQueryResults());
queryResult.getQueryColumns(), queryResult.getQueryResults());
queryResult.setTextResult(textResult); queryResult.setTextResult(textResult);
if (queryResult.getQueryState().equals(QueryState.SUCCESS) if (queryResult.getQueryState().equals(QueryState.SUCCESS)
&& queryResult.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) { && queryResult.getQueryMode().equals(LLMSqlQuery.QUERY_MODE)) {
Text2SQLExemplar exemplar = Text2SQLExemplar exemplar =
JsonUtil.toObject( JsonUtil.toObject(
JsonUtil.toString( JsonUtil.toString(executeContext.getParseInfo().getProperties()
executeContext .get(Text2SQLExemplar.PROPERTY_KEY)),
.getParseInfo()
.getProperties()
.get(Text2SQLExemplar.PROPERTY_KEY)),
Text2SQLExemplar.class); Text2SQLExemplar.class);
MemoryService memoryService = ContextUtils.getBean(MemoryService.class); MemoryService memoryService = ContextUtils.getBean(MemoryService.class);
memoryService.createMemory( memoryService.createMemory(ChatMemoryDO.builder()
ChatMemoryDO.builder() .agentId(executeContext.getAgent().getId()).status(MemoryStatus.PENDING)
.agentId(executeContext.getAgent().getId()) .question(exemplar.getQuestion()).sideInfo(exemplar.getSideInfo())
.status(MemoryStatus.PENDING) .dbSchema(exemplar.getDbSchema()).s2sql(exemplar.getSql())
.question(exemplar.getQuestion()) .createdBy(executeContext.getUser().getName())
.sideInfo(exemplar.getSideInfo()) .updatedBy(executeContext.getUser().getName()).createdAt(new Date())
.dbSchema(exemplar.getDbSchema()) .build());
.s2sql(exemplar.getSql())
.createdBy(executeContext.getUser().getName())
.updatedBy(executeContext.getUser().getName())
.createdAt(new Date())
.build());
} }
} }

View File

@@ -27,25 +27,22 @@ public class MemoryReviewTask {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
private static final String INSTRUCTION = private static final String INSTRUCTION = ""
"" + "\n#Role: You are a senior data engineer experienced in writing SQL."
+ "\n#Role: You are a senior data engineer experienced in writing SQL." + "\n#Task: Your will be provided with a user question and the SQL written by junior engineer,"
+ "\n#Task: Your will be provided with a user question and the SQL written by junior engineer," + "please take a review and give your opinion." + "\n#Rules: "
+ "please take a review and give your opinion." + "1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`."
+ "\n#Rules: " + "2.NO NEED to include date filter in the where clause if not explicitly expressed in the `Question`."
+ "1.ALWAYS follow the output format: `opinion=(POSITIVE|NEGATIVE),comment=(your comment)`." + "\n#Question: %s" + "\n#Schema: %s" + "\n#SideInfo: %s" + "\n#SQL: %s"
+ "2.NO NEED to include date filter in the where clause if not explicitly expressed in the `Question`." + "\n#Response: ";
+ "\n#Question: %s"
+ "\n#Schema: %s"
+ "\n#SideInfo: %s"
+ "\n#SQL: %s"
+ "\n#Response: ";
private static final Pattern OUTPUT_PATTERN = Pattern.compile("opinion=(.*),.*comment=(.*)"); private static final Pattern OUTPUT_PATTERN = Pattern.compile("opinion=(.*),.*comment=(.*)");
@Autowired private MemoryService memoryService; @Autowired
private MemoryService memoryService;
@Autowired private AgentService agentService; @Autowired
private AgentService agentService;
@Scheduled(fixedDelay = 60 * 1000) @Scheduled(fixedDelay = 60 * 1000)
public void review() { public void review() {
@@ -78,8 +75,8 @@ public class MemoryReviewTask {
} }
private String createPromptString(ChatMemoryDO m) { private String createPromptString(ChatMemoryDO m) {
return String.format( return String.format(INSTRUCTION, m.getQuestion(), m.getDbSchema(), m.getSideInfo(),
INSTRUCTION, m.getQuestion(), m.getDbSchema(), m.getSideInfo(), m.getS2sql()); m.getS2sql());
} }
private void processResponse(String response, ChatMemoryDO m) { private void processResponse(String response, ChatMemoryDO m) {

View File

@@ -21,13 +21,10 @@ public class NL2PluginParser implements ChatQueryParser {
return; return;
} }
pluginRecognizers.forEach( pluginRecognizers.forEach(pluginRecognizer -> {
pluginRecognizer -> { pluginRecognizer.recognize(parseContext, parseResp);
pluginRecognizer.recognize(parseContext, parseResp); log.info("{} recallResult:{}", pluginRecognizer.getClass().getSimpleName(),
log.info( JsonUtil.toString(parseResp));
"{} recallResult:{}", });
pluginRecognizer.getClass().getSimpleName(),
JsonUtil.toString(parseResp));
});
} }
} }

View File

@@ -52,33 +52,27 @@ public class NL2SQLParser implements ChatQueryParser {
private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline"); private static final Logger keyPipelineLog = LoggerFactory.getLogger("keyPipeline");
private static final String REWRITE_USER_QUESTION_INSTRUCTION = private static final String REWRITE_USER_QUESTION_INSTRUCTION = ""
"" + "#Role: You are a data product manager experienced in data requirements."
+ "#Role: You are a data product manager experienced in data requirements." + "#Task: Your will be provided with current and history questions asked by a user,"
+ "#Task: Your will be provided with current and history questions asked by a user," + "along with their mapped schema elements(metric, dimension and value),"
+ "along with their mapped schema elements(metric, dimension and value)," + "please try understanding the semantics and rewrite a question." + "#Rules: "
+ "please try understanding the semantics and rewrite a question." + "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges."
+ "#Rules: " + "2.ONLY respond with the rewritten question."
+ "1.ALWAYS keep relevant entities, metrics, dimensions, values and date ranges." + "#Current Question: {{current_question}}"
+ "2.ONLY respond with the rewritten question." + "#Current Mapped Schema: {{current_schema}}"
+ "#Current Question: {{current_question}}" + "#History Question: {{history_question}}"
+ "#Current Mapped Schema: {{current_schema}}" + "#History Mapped Schema: {{history_schema}}" + "#History SQL: {{history_sql}}"
+ "#History Question: {{history_question}}" + "#Rewritten Question: ";
+ "#History Mapped Schema: {{history_schema}}"
+ "#History SQL: {{history_sql}}"
+ "#Rewritten Question: ";
private static final String REWRITE_ERROR_MESSAGE_INSTRUCTION = private static final String REWRITE_ERROR_MESSAGE_INSTRUCTION = ""
"" + "#Role: You are a data business partner who closely interacts with business people.\n"
+ "#Role: You are a data business partner who closely interacts with business people.\n" + "#Task: Your will be provided with user input, system output and some examples, "
+ "#Task: Your will be provided with user input, system output and some examples, " + "please respond shortly to teach user how to ask the right question, "
+ "please respond shortly to teach user how to ask the right question, " + "by using `Examples` as references."
+ "by using `Examples` as references." + "#Rules: ALWAYS respond with the same language as the `Input`.\n"
+ "#Rules: ALWAYS respond with the same language as the `Input`.\n" + "#Input: {{user_question}}\n" + "#Output: {{system_message}}\n"
+ "#Input: {{user_question}}\n" + "#Examples: {{examples}}\n" + "#Response: ";
+ "#Output: {{system_message}}\n"
+ "#Examples: {{examples}}\n"
+ "#Response: ";
@Override @Override
public void parse(ParseContext parseContext, ParseResp parseResp) { public void parse(ParseContext parseContext, ParseResp parseResp) {
@@ -100,13 +94,10 @@ public class NL2SQLParser implements ChatQueryParser {
parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses()); parseResp.getSelectedParses().addAll(text2SqlParseResp.getSelectedParses());
} else { } else {
if (parseContext.enbaleLLM()) { if (parseContext.enbaleLLM()) {
parseResp.setErrorMsg( parseResp.setErrorMsg(rewriteErrorMessage(parseContext.getQueryText(),
rewriteErrorMessage( text2SqlParseResp.getErrorMsg(), queryNLReq.getDynamicExemplars(),
parseContext.getQueryText(), parseContext.getAgent().getExamples(),
text2SqlParseResp.getErrorMsg(), parseContext.getAgent().getModelConfig()));
queryNLReq.getDynamicExemplars(),
parseContext.getAgent().getExamples(),
parseContext.getAgent().getModelConfig()));
} }
} }
parseResp.setState(text2SqlParseResp.getState()); parseResp.setState(text2SqlParseResp.getState());
@@ -141,40 +132,26 @@ public class NL2SQLParser implements ChatQueryParser {
StringBuilder textBuilder = new StringBuilder(); StringBuilder textBuilder = new StringBuilder();
textBuilder.append("**数据集:** ").append(parseInfo.getDataSet().getName()).append(" "); textBuilder.append("**数据集:** ").append(parseInfo.getDataSet().getName()).append(" ");
Optional<SchemaElement> metric = parseInfo.getMetrics().stream().findFirst(); Optional<SchemaElement> metric = parseInfo.getMetrics().stream().findFirst();
metric.ifPresent( metric.ifPresent(schemaElement -> textBuilder.append("**指标:** ")
schemaElement -> .append(schemaElement.getName()).append(" "));
textBuilder.append("**指标:** ").append(schemaElement.getName()).append(" ")); List<String> dimensionNames = parseInfo.getDimensions().stream().map(SchemaElement::getName)
List<String> dimensionNames = .filter(Objects::nonNull).collect(Collectors.toList());
parseInfo.getDimensions().stream()
.map(SchemaElement::getName)
.filter(Objects::nonNull)
.collect(Collectors.toList());
if (!CollectionUtils.isEmpty(dimensionNames)) { if (!CollectionUtils.isEmpty(dimensionNames)) {
textBuilder.append("**维度:** ").append(String.join(",", dimensionNames)); textBuilder.append("**维度:** ").append(String.join(",", dimensionNames));
} }
textBuilder.append("\n\n**筛选条件:** \n"); textBuilder.append("\n\n**筛选条件:** \n");
if (parseInfo.getDateInfo() != null) { if (parseInfo.getDateInfo() != null) {
textBuilder textBuilder.append("**数据时间:** ").append(parseInfo.getDateInfo().getStartDate())
.append("**数据时间:** ") .append("~").append(parseInfo.getDateInfo().getEndDate()).append(" ");
.append(parseInfo.getDateInfo().getStartDate())
.append("~")
.append(parseInfo.getDateInfo().getEndDate())
.append(" ");
} }
if (!CollectionUtils.isEmpty(parseInfo.getDimensionFilters()) if (!CollectionUtils.isEmpty(parseInfo.getDimensionFilters())
|| CollectionUtils.isEmpty(parseInfo.getMetricFilters())) { || CollectionUtils.isEmpty(parseInfo.getMetricFilters())) {
Set<QueryFilter> queryFilters = parseInfo.getDimensionFilters(); Set<QueryFilter> queryFilters = parseInfo.getDimensionFilters();
queryFilters.addAll(parseInfo.getMetricFilters()); queryFilters.addAll(parseInfo.getMetricFilters());
for (QueryFilter queryFilter : queryFilters) { for (QueryFilter queryFilter : queryFilters) {
textBuilder textBuilder.append("**").append(queryFilter.getName()).append("**").append(" ")
.append("**") .append(queryFilter.getOperator().getValue()).append(" ")
.append(queryFilter.getName()) .append(queryFilter.getValue()).append(" ");
.append("**")
.append(" ")
.append(queryFilter.getOperator().getValue())
.append(" ")
.append(queryFilter.getValue())
.append(" ");
} }
} }
parseInfo.setTextInfo(textBuilder.toString()); parseInfo.setTextInfo(textBuilder.toString());
@@ -187,8 +164,7 @@ public class NL2SQLParser implements ChatQueryParser {
Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE)); Boolean.valueOf(parserConfig.getParameterValue(PARSER_MULTI_TURN_ENABLE));
Boolean multiTurnConfig = Boolean multiTurnConfig =
agentMultiTurnConfig != null agentMultiTurnConfig != null ? agentMultiTurnConfig.isEnableMultiTurn()
? agentMultiTurnConfig.isEnableMultiTurn()
: globalMultiTurnConfig; : globalMultiTurnConfig;
if (!Boolean.TRUE.equals(multiTurnConfig)) { if (!Boolean.TRUE.equals(multiTurnConfig)) {
return; return;
@@ -232,30 +208,20 @@ public class NL2SQLParser implements ChatQueryParser {
QueryNLReq rewrittenQueryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext); QueryNLReq rewrittenQueryNLReq = QueryReqConverter.buildText2SqlQueryReq(parseContext);
MapResp rewrittenQueryMapResult = chatLayerService.performMapping(rewrittenQueryNLReq); MapResp rewrittenQueryMapResult = chatLayerService.performMapping(rewrittenQueryNLReq);
parseContext.setMapInfo(rewrittenQueryMapResult.getMapInfo()); parseContext.setMapInfo(rewrittenQueryMapResult.getMapInfo());
log.info( log.info("Last Query: {} Current Query: {}, Rewritten Query: {}", lastQuery.getQueryText(),
"Last Query: {} Current Query: {}, Rewritten Query: {}", currentMapResult.getQueryText(), rewrittenQuery);
lastQuery.getQueryText(),
currentMapResult.getQueryText(),
rewrittenQuery);
} }
private String rewriteErrorMessage( private String rewriteErrorMessage(String userQuestion, String errMsg,
String userQuestion, List<Text2SQLExemplar> similarExemplars, List<String> agentExamples,
String errMsg,
List<Text2SQLExemplar> similarExemplars,
List<String> agentExamples,
ChatModelConfig modelConfig) { ChatModelConfig modelConfig) {
Map<String, Object> variables = new HashMap<>(); Map<String, Object> variables = new HashMap<>();
variables.put("user_question", userQuestion); variables.put("user_question", userQuestion);
variables.put("system_message", errMsg); variables.put("system_message", errMsg);
StringBuilder exampleStr = new StringBuilder(); StringBuilder exampleStr = new StringBuilder();
similarExemplars.forEach( similarExemplars.forEach(e -> exampleStr.append(
e -> String.format("<Question:{%s},Schema:{%s}> ", e.getQuestion(), e.getDbSchema())));
exampleStr.append(
String.format(
"<Question:{%s},Schema:{%s}> ",
e.getQuestion(), e.getDbSchema())));
agentExamples.forEach(e -> exampleStr.append(String.format("<Question:{%s}> ", e))); agentExamples.forEach(e -> exampleStr.append(String.format("<Question:{%s}> ", e)));
variables.put("examples", exampleStr); variables.put("examples", exampleStr);
@@ -297,18 +263,13 @@ public class NL2SQLParser implements ChatQueryParser {
private List<QueryResp> getHistoryQueries(int chatId, int multiNum) { private List<QueryResp> getHistoryQueries(int chatId, int multiNum) {
ChatManageService chatManageService = ContextUtils.getBean(ChatManageService.class); ChatManageService chatManageService = ContextUtils.getBean(ChatManageService.class);
List<QueryResp> contextualParseInfoList = List<QueryResp> contextualParseInfoList = chatManageService.getChatQueries(chatId).stream()
chatManageService.getChatQueries(chatId).stream() .filter(q -> Objects.nonNull(q.getQueryResult())
.filter( && q.getQueryResult().getQueryState() == QueryState.SUCCESS)
q -> .collect(Collectors.toList());
Objects.nonNull(q.getQueryResult())
&& q.getQueryResult().getQueryState()
== QueryState.SUCCESS)
.collect(Collectors.toList());
List<QueryResp> contextualList = List<QueryResp> contextualList = contextualParseInfoList.subList(0,
contextualParseInfoList.subList( Math.min(multiNum, contextualParseInfoList.size()));
0, Math.min(multiNum, contextualParseInfoList.size()));
Collections.reverse(contextualList); Collections.reverse(contextualList);
return contextualList; return contextualList;
} }
@@ -320,9 +281,8 @@ public class NL2SQLParser implements ChatQueryParser {
ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class); ParserConfig parserConfig = ContextUtils.getBean(ParserConfig.class);
int exemplarRecallNumber = int exemplarRecallNumber =
Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER)); Integer.valueOf(parserConfig.getParameterValue(PARSER_EXEMPLAR_RECALL_NUMBER));
List<Text2SQLExemplar> exemplars = List<Text2SQLExemplar> exemplars = exemplarManager.recallExemplars(memoryCollectionName,
exemplarManager.recallExemplars( queryNLReq.getQueryText(), exemplarRecallNumber);
memoryCollectionName, queryNLReq.getQueryText(), exemplarRecallNumber);
queryNLReq.getDynamicExemplars().addAll(exemplars); queryNLReq.getDynamicExemplars().addAll(exemplars);
} }
} }

View File

@@ -10,11 +10,6 @@ import org.springframework.stereotype.Service;
public class ParserConfig extends ParameterConfig { public class ParserConfig extends ParameterConfig {
public static final Parameter PARSER_MULTI_TURN_ENABLE = public static final Parameter PARSER_MULTI_TURN_ENABLE =
new Parameter( new Parameter("s2.parser.multi-turn.enable", "false", "是否开启多轮对话", "开启多轮对话将消耗更多token",
"s2.parser.multi-turn.enable", "bool", "Parser相关配置");
"false",
"是否开启多轮对话",
"开启多轮对话将消耗更多token",
"bool",
"Parser相关配置");
} }

View File

@@ -1,10 +1,7 @@
package com.tencent.supersonic.chat.server.persistence.dataobject; package com.tencent.supersonic.chat.server.persistence.dataobject;
public enum CostType { public enum CostType {
MAPPER(1, "mapper"), MAPPER(1, "mapper"), PARSER(2, "parser"), QUERY(3, "query"), PROCESSOR(4, "processor");
PARSER(2, "parser"),
QUERY(3, "query"),
PROCESSOR(4, "processor");
private Integer type; private Integer type;
private String name; private String name;

View File

@@ -5,4 +5,5 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.AgentDO;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
@Mapper @Mapper
public interface AgentDOMapper extends BaseMapper<AgentDO> {} public interface AgentDOMapper extends BaseMapper<AgentDO> {
}

View File

@@ -5,4 +5,5 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatMemoryDO;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
@Mapper @Mapper
public interface ChatMemoryMapper extends BaseMapper<ChatMemoryDO> {} public interface ChatMemoryMapper extends BaseMapper<ChatMemoryDO> {
}

View File

@@ -5,4 +5,5 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.ChatQueryDO;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
@Mapper @Mapper
public interface ChatQueryDOMapper extends BaseMapper<ChatQueryDO> {} public interface ChatQueryDOMapper extends BaseMapper<ChatQueryDO> {
}

View File

@@ -5,4 +5,5 @@ import com.tencent.supersonic.chat.server.persistence.dataobject.PluginDO;
import org.apache.ibatis.annotations.Mapper; import org.apache.ibatis.annotations.Mapper;
@Mapper @Mapper
public interface PluginDOMapper extends BaseMapper<PluginDO> {} public interface PluginDOMapper extends BaseMapper<PluginDO> {
}

View File

@@ -30,9 +30,7 @@ public interface ChatQueryRepository {
Long createChatQuery(ChatParseReq chatParseReq); Long createChatQuery(ChatParseReq chatParseReq);
List<ChatParseDO> batchSaveParseInfo( List<ChatParseDO> batchSaveParseInfo(ChatParseReq chatParseReq, ParseResp parseResult,
ChatParseReq chatParseReq,
ParseResp parseResult,
List<SemanticParseInfo> candidateParses); List<SemanticParseInfo> candidateParses);
ChatParseDO getParseInfo(Long questionId, int parseId); ChatParseDO getParseInfo(Long questionId, int parseId);

View File

@@ -23,8 +23,8 @@ public class ChatConfigRepositoryImpl implements ChatConfigRepository {
private final ChatConfigHelper chatConfigHelper; private final ChatConfigHelper chatConfigHelper;
private final ChatConfigMapper chatConfigMapper; private final ChatConfigMapper chatConfigMapper;
public ChatConfigRepositoryImpl( public ChatConfigRepositoryImpl(ChatConfigHelper chatConfigHelper,
ChatConfigHelper chatConfigHelper, ChatConfigMapper chatConfigMapper) { ChatConfigMapper chatConfigMapper) {
this.chatConfigHelper = chatConfigHelper; this.chatConfigHelper = chatConfigHelper;
this.chatConfigMapper = chatConfigMapper; this.chatConfigMapper = chatConfigMapper;
} }
@@ -52,11 +52,8 @@ public class ChatConfigRepositoryImpl implements ChatConfigRepository {
List<ChatConfigDO> chaConfigDOList = chatConfigMapper.search(filterInternal); List<ChatConfigDO> chaConfigDOList = chatConfigMapper.search(filterInternal);
if (!CollectionUtils.isEmpty(chaConfigDOList)) { if (!CollectionUtils.isEmpty(chaConfigDOList)) {
chaConfigDOList.stream() chaConfigDOList.stream()
.forEach( .forEach(chaConfigDO -> chaConfigDescriptorList.add(chatConfigHelper
chaConfigDO -> .chatConfigDO2Descriptor(chaConfigDO.getModelId(), chaConfigDO)));
chaConfigDescriptorList.add(
chatConfigHelper.chatConfigDO2Descriptor(
chaConfigDO.getModelId(), chaConfigDO)));
} }
return chaConfigDescriptorList; return chaConfigDescriptorList;
} }

View File

@@ -40,11 +40,14 @@ import java.util.stream.Collectors;
@Slf4j @Slf4j
public class ChatQueryRepositoryImpl implements ChatQueryRepository { public class ChatQueryRepositoryImpl implements ChatQueryRepository {
@Autowired private ChatQueryDOMapper chatQueryDOMapper; @Autowired
private ChatQueryDOMapper chatQueryDOMapper;
@Autowired private ChatParseMapper chatParseMapper; @Autowired
private ChatParseMapper chatParseMapper;
@Autowired private ShowCaseCustomMapper showCaseCustomMapper; @Autowired
private ShowCaseCustomMapper showCaseCustomMapper;
@Override @Override
public PageInfo<QueryResp> getChatQuery(PageQueryInfoReq pageQueryInfoReq, Long chatId) { public PageInfo<QueryResp> getChatQuery(PageQueryInfoReq pageQueryInfoReq, Long chatId) {
@@ -67,11 +70,9 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
.doSelectPageInfo(() -> chatQueryDOMapper.selectList(queryWrapper)); .doSelectPageInfo(() -> chatQueryDOMapper.selectList(queryWrapper));
PageInfo<QueryResp> chatQueryVOPageInfo = PageUtils.pageInfo2PageInfoVo(pageInfo); PageInfo<QueryResp> chatQueryVOPageInfo = PageUtils.pageInfo2PageInfoVo(pageInfo);
chatQueryVOPageInfo.setList( chatQueryVOPageInfo.setList(pageInfo.getList().stream()
pageInfo.getList().stream() .sorted(Comparator.comparingInt(o -> o.getQuestionId().intValue()))
.sorted(Comparator.comparingInt(o -> o.getQuestionId().intValue())) .map(this::convertTo).collect(Collectors.toList()));
.map(this::convertTo)
.collect(Collectors.toList()));
return chatQueryVOPageInfo; return chatQueryVOPageInfo;
} }
@@ -94,22 +95,16 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
QueryWrapper<ChatQueryDO> queryWrapper = new QueryWrapper<>(); QueryWrapper<ChatQueryDO> queryWrapper = new QueryWrapper<>();
queryWrapper.lambda().eq(ChatQueryDO::getChatId, chatId); queryWrapper.lambda().eq(ChatQueryDO::getChatId, chatId);
queryWrapper.lambda().orderByDesc(ChatQueryDO::getQuestionId); queryWrapper.lambda().orderByDesc(ChatQueryDO::getQuestionId);
return chatQueryDOMapper.selectList(queryWrapper).stream() return chatQueryDOMapper.selectList(queryWrapper).stream().map(q -> convertTo(q))
.map(q -> convertTo(q))
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
@Override @Override
public List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId) { public List<QueryResp> queryShowCase(PageQueryInfoReq pageQueryInfoReq, int agentId) {
return showCaseCustomMapper return showCaseCustomMapper
.queryShowCase( .queryShowCase(pageQueryInfoReq.getLimitStart(), pageQueryInfoReq.getPageSize(),
pageQueryInfoReq.getLimitStart(), agentId, pageQueryInfoReq.getUserName())
pageQueryInfoReq.getPageSize(), .stream().map(this::convertTo).collect(Collectors.toList());
agentId,
pageQueryInfoReq.getUserName())
.stream()
.map(this::convertTo)
.collect(Collectors.toList());
} }
private QueryResp convertTo(ChatQueryDO chatQueryDO) { private QueryResp convertTo(ChatQueryDO chatQueryDO) {
@@ -121,9 +116,8 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
queryResult.setQueryId(chatQueryDO.getQuestionId()); queryResult.setQueryId(chatQueryDO.getQuestionId());
queryResp.setQueryResult(queryResult); queryResp.setQueryResult(queryResult);
} }
queryResp.setSimilarQueries( queryResp.setSimilarQueries(JSONObject.parseArray(chatQueryDO.getSimilarQueries(),
JSONObject.parseArray( SimilarQueryRecallResp.class));
chatQueryDO.getSimilarQueries(), SimilarQueryRecallResp.class));
queryResp.setParseTimeCost( queryResp.setParseTimeCost(
JsonUtil.toObject(chatQueryDO.getParseTimeCost(), ParseTimeCostResp.class)); JsonUtil.toObject(chatQueryDO.getParseTimeCost(), ParseTimeCostResp.class));
return queryResp; return queryResp;
@@ -147,9 +141,7 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
} }
@Override @Override
public List<ChatParseDO> batchSaveParseInfo( public List<ChatParseDO> batchSaveParseInfo(ChatParseReq chatParseReq, ParseResp parseResult,
ChatParseReq chatParseReq,
ParseResp parseResult,
List<SemanticParseInfo> candidateParses) { List<SemanticParseInfo> candidateParses) {
List<ChatParseDO> chatParseDOList = new ArrayList<>(); List<ChatParseDO> chatParseDOList = new ArrayList<>();
getChatParseDO(chatParseReq, parseResult.getQueryId(), candidateParses, chatParseDOList); getChatParseDO(chatParseReq, parseResult.getQueryId(), candidateParses, chatParseDOList);
@@ -159,11 +151,8 @@ public class ChatQueryRepositoryImpl implements ChatQueryRepository {
return chatParseDOList; return chatParseDOList;
} }
public void getChatParseDO( public void getChatParseDO(ChatParseReq chatParseReq, Long queryId,
ChatParseReq chatParseReq, List<SemanticParseInfo> parses, List<ChatParseDO> chatParseDOList) {
Long queryId,
List<SemanticParseInfo> parses,
List<ChatParseDO> chatParseDOList) {
for (int i = 0; i < parses.size(); i++) { for (int i = 0; i < parses.size(); i++) {
ChatParseDO chatParseDO = new ChatParseDO(); ChatParseDO chatParseDO = new ChatParseDO();
chatParseDO.setChatId(chatParseReq.getChatId()); chatParseDO.setChatId(chatParseReq.getChatId());

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.chat.server.plugin; package com.tencent.supersonic.chat.server.plugin;
public enum ParseMode { public enum ParseMode {
EMBEDDING_RECALL, EMBEDDING_RECALL, FUNCTION_CALL;
FUNCTION_CALL;
} }

View File

@@ -46,9 +46,11 @@ import java.util.stream.Collectors;
@Component @Component
public class PluginManager { public class PluginManager {
@Autowired private EmbeddingConfig embeddingConfig; @Autowired
private EmbeddingConfig embeddingConfig;
@Autowired private EmbeddingService embeddingService; @Autowired
private EmbeddingService embeddingService;
public static List<ChatPlugin> getPluginAgentCanSupport(ParseContext parseContext) { public static List<ChatPlugin> getPluginAgentCanSupport(ParseContext parseContext) {
PluginService pluginService = ContextUtils.getBean(PluginService.class); PluginService pluginService = ContextUtils.getBean(PluginService.class);
@@ -57,21 +59,14 @@ public class PluginManager {
if (Objects.isNull(agent)) { if (Objects.isNull(agent)) {
return plugins; return plugins;
} }
List<Long> pluginIds = List<Long> pluginIds = getPluginTools(agent).stream().map(PluginTool::getPlugins)
getPluginTools(agent).stream() .flatMap(Collection::stream).collect(Collectors.toList());
.map(PluginTool::getPlugins)
.flatMap(Collection::stream)
.collect(Collectors.toList());
if (CollectionUtils.isEmpty(pluginIds)) { if (CollectionUtils.isEmpty(pluginIds)) {
return Lists.newArrayList(); return Lists.newArrayList();
} }
plugins = plugins = plugins.stream().filter(plugin -> pluginIds.contains(plugin.getId()))
plugins.stream() .collect(Collectors.toList());
.filter(plugin -> pluginIds.contains(plugin.getId())) log.info("plugins witch can be supported by cur agent :{} {}", agent.getName(),
.collect(Collectors.toList());
log.info(
"plugins witch can be supported by cur agent :{} {}",
agent.getName(),
plugins.stream().map(ChatPlugin::getName).collect(Collectors.toList())); plugins.stream().map(ChatPlugin::getName).collect(Collectors.toList()));
return plugins; return plugins;
} }
@@ -84,8 +79,7 @@ public class PluginManager {
if (CollectionUtils.isEmpty(tools)) { if (CollectionUtils.isEmpty(tools)) {
return Lists.newArrayList(); return Lists.newArrayList();
} }
return tools.stream() return tools.stream().map(tool -> JSONObject.parseObject(tool, PluginTool.class))
.map(tool -> JSONObject.parseObject(tool, PluginTool.class))
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
@@ -142,23 +136,18 @@ public class PluginManager {
public RetrieveQueryResult recognize(String embeddingText) { public RetrieveQueryResult recognize(String embeddingText) {
RetrieveQuery retrieveQuery = RetrieveQuery retrieveQuery = RetrieveQuery.builder()
RetrieveQuery.builder() .queryTextsList(Collections.singletonList(embeddingText)).build();
.queryTextsList(Collections.singletonList(embeddingText))
.build();
List<RetrieveQueryResult> resultList = List<RetrieveQueryResult> resultList = embeddingService.retrieveQuery(
embeddingService.retrieveQuery( embeddingConfig.getPresetCollection(), retrieveQuery, embeddingConfig.getNResult());
embeddingConfig.getPresetCollection(),
retrieveQuery,
embeddingConfig.getNResult());
if (CollectionUtils.isNotEmpty(resultList)) { if (CollectionUtils.isNotEmpty(resultList)) {
for (RetrieveQueryResult embeddingResp : resultList) { for (RetrieveQueryResult embeddingResp : resultList) {
List<Retrieval> embeddingRetrievals = embeddingResp.getRetrieval(); List<Retrieval> embeddingRetrievals = embeddingResp.getRetrieval();
for (Retrieval embeddingRetrieval : embeddingRetrievals) { for (Retrieval embeddingRetrieval : embeddingRetrievals) {
embeddingRetrieval.setId( embeddingRetrieval
getPluginIdFromEmbeddingId(embeddingRetrieval.getId())); .setId(getPluginIdFromEmbeddingId(embeddingRetrieval.getId()));
} }
} }
return resultList.get(0); return resultList.get(0);
@@ -173,8 +162,8 @@ public class PluginManager {
int num = 0; int num = 0;
for (String pattern : exampleQuestions) { for (String pattern : exampleQuestions) {
TextSegment query = TextSegment.from(pattern); TextSegment query = TextSegment.from(pattern);
TextSegmentConvert.addQueryId( TextSegmentConvert.addQueryId(query,
query, generateUniqueEmbeddingId(num, plugin.getId())); generateUniqueEmbeddingId(num, plugin.getId()));
queries.add(query); queries.add(query);
num++; num++;
} }
@@ -250,14 +239,10 @@ public class PluginManager {
return Sets.newHashSet(); return Sets.newHashSet();
} }
return schemaElementMatches.stream() return schemaElementMatches.stream()
.filter( .filter(schemaElementMatch -> SchemaElementType.VALUE
schemaElementMatch -> .equals(schemaElementMatch.getElement().getType())
SchemaElementType.VALUE.equals( || SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
schemaElementMatch.getElement().getType()) .map(SchemaElementMatch::getElement).map(SchemaElement::getId)
|| SchemaElementType.ID.equals(
schemaElementMatch.getElement().getType()))
.map(SchemaElementMatch::getElement)
.map(SchemaElement::getId)
.collect(Collectors.toSet()); .collect(Collectors.toSet());
} }
@@ -270,10 +255,8 @@ public class PluginManager {
if (CollectionUtils.isEmpty(paramOptions)) { if (CollectionUtils.isEmpty(paramOptions)) {
return Lists.newArrayList(); return Lists.newArrayList();
} }
return paramOptions.stream() return paramOptions.stream().filter(
.filter( paramOption -> ParamOption.ParamType.SEMANTIC.equals(paramOption.getParamType()))
paramOption ->
ParamOption.ParamType.SEMANTIC.equals(paramOption.getParamType()))
.collect(Collectors.toList()); .collect(Collectors.toList());
} }

View File

@@ -26,13 +26,10 @@ public class ParamOption {
* forward * forward
*/ */
public enum ParamType { public enum ParamType {
CUSTOM, CUSTOM, SEMANTIC, FORWARD
SEMANTIC,
FORWARD
} }
public enum OptionType { public enum OptionType {
REQUIRED, REQUIRED, OPTIONAL
OPTIONAL
} }
} }

View File

@@ -43,40 +43,31 @@ public abstract class PluginSemanticQuery {
protected Map<String, Object> getElementMap(PluginParseResult pluginParseResult) { protected Map<String, Object> getElementMap(PluginParseResult pluginParseResult) {
Map<String, Object> elementValueMap = new HashMap<>(); Map<String, Object> elementValueMap = new HashMap<>();
Map<Long, Object> filterValueMap = getFilterMap(pluginParseResult); Map<Long, Object> filterValueMap = getFilterMap(pluginParseResult);
List<SchemaElementMatch> schemaElementMatchList = List<SchemaElementMatch> schemaElementMatchList = parseInfo.getElementMatches().stream()
parseInfo.getElementMatches().stream() .filter(schemaElementMatch -> schemaElementMatch.getFrequency() != null)
.filter(schemaElementMatch -> schemaElementMatch.getFrequency() != null) .sorted(Comparator.comparingLong(SchemaElementMatch::getFrequency).reversed())
.sorted( .collect(Collectors.toList());
Comparator.comparingLong(SchemaElementMatch::getFrequency)
.reversed())
.collect(Collectors.toList());
if (!CollectionUtils.isEmpty(schemaElementMatchList)) { if (!CollectionUtils.isEmpty(schemaElementMatchList)) {
schemaElementMatchList.stream() schemaElementMatchList.stream().filter(schemaElementMatch -> SchemaElementType.VALUE
.filter( .equals(schemaElementMatch.getElement().getType())
schemaElementMatch -> || SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
SchemaElementType.VALUE.equals(
schemaElementMatch.getElement().getType())
|| SchemaElementType.ID.equals(
schemaElementMatch.getElement().getType()))
.filter(schemaElementMatch -> schemaElementMatch.getSimilarity() == 1.0) .filter(schemaElementMatch -> schemaElementMatch.getSimilarity() == 1.0)
.forEach( .forEach(schemaElementMatch -> {
schemaElementMatch -> { Object queryFilterValue =
Object queryFilterValue = filterValueMap.get(schemaElementMatch.getElement().getId());
filterValueMap.get(schemaElementMatch.getElement().getId()); if (queryFilterValue != null) {
if (queryFilterValue != null) { if (String.valueOf(queryFilterValue)
if (String.valueOf(queryFilterValue) .equals(String.valueOf(schemaElementMatch.getWord()))) {
.equals(String.valueOf(schemaElementMatch.getWord()))) { elementValueMap.put(
elementValueMap.put( String.valueOf(schemaElementMatch.getElement().getId()),
String.valueOf( schemaElementMatch.getWord());
schemaElementMatch.getElement().getId()), }
schemaElementMatch.getWord()); } else {
} elementValueMap.computeIfAbsent(
} else { String.valueOf(schemaElementMatch.getElement().getId()),
elementValueMap.computeIfAbsent( k -> schemaElementMatch.getWord());
String.valueOf(schemaElementMatch.getElement().getId()), }
k -> schemaElementMatch.getWord()); });
}
});
} }
return elementValueMap; return elementValueMap;
} }

View File

@@ -41,10 +41,8 @@ public class WebPageQuery extends PluginSemanticQuery {
QueryResult queryResult = new QueryResult(); QueryResult queryResult = new QueryResult();
queryResult.setQueryMode(QUERY_MODE); queryResult.setQueryMode(QUERY_MODE);
Map<String, Object> properties = parseInfo.getProperties(); Map<String, Object> properties = parseInfo.getProperties();
PluginParseResult pluginParseResult = PluginParseResult pluginParseResult = JsonUtil.toObject(
JsonUtil.toObject( JsonUtil.toString(properties.get(Constants.CONTEXT)), PluginParseResult.class);
JsonUtil.toString(properties.get(Constants.CONTEXT)),
PluginParseResult.class);
WebPageResp webPageResponse = buildResponse(pluginParseResult); WebPageResp webPageResponse = buildResponse(pluginParseResult);
queryResult.setResponse(webPageResponse); queryResult.setResponse(webPageResponse);
queryResult.setQueryState(QueryState.SUCCESS); queryResult.setQueryState(QueryState.SUCCESS);

View File

@@ -45,10 +45,8 @@ public class WebServiceQuery extends PluginSemanticQuery {
QueryResult queryResult = new QueryResult(); QueryResult queryResult = new QueryResult();
queryResult.setQueryMode(QUERY_MODE); queryResult.setQueryMode(QUERY_MODE);
Map<String, Object> properties = parseInfo.getProperties(); Map<String, Object> properties = parseInfo.getProperties();
PluginParseResult pluginParseResult = PluginParseResult pluginParseResult = JsonUtil.toObject(
JsonUtil.toObject( JsonUtil.toString(properties.get(Constants.CONTEXT)), PluginParseResult.class);
JsonUtil.toString(properties.get(Constants.CONTEXT)),
PluginParseResult.class);
WebServiceResp webServiceResponse = buildResponse(pluginParseResult); WebServiceResp webServiceResponse = buildResponse(pluginParseResult);
Object object = webServiceResponse.getResult(); Object object = webServiceResponse.getResult();
// in order to show webServiceQuery result int frontend conveniently, // in order to show webServiceQuery result int frontend conveniently,
@@ -74,9 +72,8 @@ public class WebServiceQuery extends PluginSemanticQuery {
protected WebServiceResp buildResponse(PluginParseResult pluginParseResult) { protected WebServiceResp buildResponse(PluginParseResult pluginParseResult) {
WebServiceResp webServiceResponse = new WebServiceResp(); WebServiceResp webServiceResponse = new WebServiceResp();
ChatPlugin plugin = pluginParseResult.getPlugin(); ChatPlugin plugin = pluginParseResult.getPlugin();
WebBase webBase = WebBase webBase = fillWebBaseResult(JsonUtil.toObject(plugin.getConfig(), WebBase.class),
fillWebBaseResult( pluginParseResult);
JsonUtil.toObject(plugin.getConfig(), WebBase.class), pluginParseResult);
webServiceResponse.setWebBase(webBase); webServiceResponse.setWebBase(webBase);
List<ParamOption> paramOptions = webBase.getParamOptions(); List<ParamOption> paramOptions = webBase.getParamOptions();
Map<String, Object> params = new HashMap<>(); Map<String, Object> params = new HashMap<>();

View File

@@ -41,17 +41,16 @@ public abstract class PluginRecognizer {
public abstract PluginRecallResult recallPlugin(ParseContext parseContext); public abstract PluginRecallResult recallPlugin(ParseContext parseContext);
public void buildQuery( public void buildQuery(ParseContext parseContext, ParseResp parseResp,
ParseContext parseContext, ParseResp parseResp, PluginRecallResult pluginRecallResult) { PluginRecallResult pluginRecallResult) {
ChatPlugin plugin = pluginRecallResult.getPlugin(); ChatPlugin plugin = pluginRecallResult.getPlugin();
Set<Long> dataSetIds = pluginRecallResult.getDataSetIds(); Set<Long> dataSetIds = pluginRecallResult.getDataSetIds();
if (plugin.isContainsAllDataSet()) { if (plugin.isContainsAllDataSet()) {
dataSetIds = Sets.newHashSet(-1L); dataSetIds = Sets.newHashSet(-1L);
} }
for (Long dataSetId : dataSetIds) { for (Long dataSetId : dataSetIds) {
SemanticParseInfo semanticParseInfo = SemanticParseInfo semanticParseInfo = buildSemanticParseInfo(dataSetId, plugin,
buildSemanticParseInfo( parseContext, pluginRecallResult.getDistance());
dataSetId, plugin, parseContext, pluginRecallResult.getDistance());
semanticParseInfo.setQueryMode(plugin.getType()); semanticParseInfo.setQueryMode(plugin.getType());
semanticParseInfo.setScore(pluginRecallResult.getScore()); semanticParseInfo.setScore(pluginRecallResult.getScore());
parseResp.getSelectedParses().add(semanticParseInfo); parseResp.getSelectedParses().add(semanticParseInfo);
@@ -62,8 +61,8 @@ public abstract class PluginRecognizer {
return PluginManager.getPluginAgentCanSupport(parseContext); return PluginManager.getPluginAgentCanSupport(parseContext);
} }
protected SemanticParseInfo buildSemanticParseInfo( protected SemanticParseInfo buildSemanticParseInfo(Long dataSetId, ChatPlugin plugin,
Long dataSetId, ChatPlugin plugin, ParseContext parseContext, double distance) { ParseContext parseContext, double distance) {
List<SchemaElementMatch> schemaElementMatches = List<SchemaElementMatch> schemaElementMatches =
parseContext.getMapInfo().getMatchedElements(dataSetId); parseContext.getMapInfo().getMatchedElements(dataSetId);
QueryFilters queryFilters = parseContext.getQueryFilters(); QueryFilters queryFilters = parseContext.getQueryFilters();
@@ -97,21 +96,17 @@ public abstract class PluginRecognizer {
return; return;
} }
schemaElementMatches.stream() schemaElementMatches.stream()
.filter( .filter(schemaElementMatch -> SchemaElementType.VALUE
schemaElementMatch -> .equals(schemaElementMatch.getElement().getType())
SchemaElementType.VALUE.equals( || SchemaElementType.ID.equals(schemaElementMatch.getElement().getType()))
schemaElementMatch.getElement().getType()) .forEach(schemaElementMatch -> {
|| SchemaElementType.ID.equals( QueryFilter queryFilter = new QueryFilter();
schemaElementMatch.getElement().getType())) queryFilter.setValue(schemaElementMatch.getWord());
.forEach( queryFilter.setElementID(schemaElementMatch.getElement().getId());
schemaElementMatch -> { queryFilter.setName(schemaElementMatch.getElement().getName());
QueryFilter queryFilter = new QueryFilter(); queryFilter.setOperator(FilterOperatorEnum.EQUALS);
queryFilter.setValue(schemaElementMatch.getWord()); queryFilter.setBizName(schemaElementMatch.getElement().getBizName());
queryFilter.setElementID(schemaElementMatch.getElement().getId()); semanticParseInfo.getDimensionFilters().add(queryFilter);
queryFilter.setName(schemaElementMatch.getElement().getName()); });
queryFilter.setOperator(FilterOperatorEnum.EQUALS);
queryFilter.setBizName(schemaElementMatch.getElement().getBizName());
semanticParseInfo.getDimensionFilters().add(queryFilter);
});
} }
} }

View File

@@ -53,12 +53,8 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer {
plugin.setParseMode(ParseMode.EMBEDDING_RECALL); plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
double similarity = embeddingRetrieval.getSimilarity(); double similarity = embeddingRetrieval.getSimilarity();
double score = parseContext.getQueryText().length() * similarity; double score = parseContext.getQueryText().length() * similarity;
return PluginRecallResult.builder() return PluginRecallResult.builder().plugin(plugin).dataSetIds(dataSetList)
.plugin(plugin) .score(score).distance(similarity).build();
.dataSetIds(dataSetList)
.score(score)
.distance(similarity)
.build();
} }
} }
return null; return null;
@@ -71,12 +67,9 @@ public class EmbeddingRecallRecognizer extends PluginRecognizer {
List<Retrieval> embeddingRetrievals = embeddingResp.getRetrieval(); List<Retrieval> embeddingRetrievals = embeddingResp.getRetrieval();
if (!CollectionUtils.isEmpty(embeddingRetrievals)) { if (!CollectionUtils.isEmpty(embeddingRetrievals)) {
embeddingRetrievals = embeddingRetrievals = embeddingRetrievals.stream()
embeddingRetrievals.stream() .sorted(Comparator.comparingDouble(o -> Math.abs(o.getSimilarity())))
.sorted( .collect(Collectors.toList());
Comparator.comparingDouble(
o -> Math.abs(o.getSimilarity())))
.collect(Collectors.toList());
embeddingResp.setRetrieval(embeddingRetrievals); embeddingResp.setRetrieval(embeddingRetrievals);
} }
return embeddingRetrievals; return embeddingRetrievals;

View File

@@ -1,4 +1,5 @@
package com.tencent.supersonic.chat.server.processor; package com.tencent.supersonic.chat.server.processor;
/** A ResultProcessor wraps things up before returning results to users. */ /** A ResultProcessor wraps things up before returning results to users. */
public interface ResultProcessor {} public interface ResultProcessor {
}

View File

@@ -52,28 +52,20 @@ public class DimensionRecommendProcessor implements ExecuteResultProcessor {
List<Long> drillDownDimensions = Lists.newArrayList(); List<Long> drillDownDimensions = Lists.newArrayList();
Set<SchemaElement> metricElements = dataSetSchema.getMetrics(); Set<SchemaElement> metricElements = dataSetSchema.getMetrics();
if (!CollectionUtils.isEmpty(metricElements)) { if (!CollectionUtils.isEmpty(metricElements)) {
Optional<SchemaElement> metric = Optional<SchemaElement> metric = metricElements.stream()
metricElements.stream() .filter(schemaElement -> metricId.equals(schemaElement.getId())
.filter( && !CollectionUtils.isEmpty(schemaElement.getRelatedSchemaElements()))
schemaElement -> .findFirst();
metricId.equals(schemaElement.getId())
&& !CollectionUtils.isEmpty(
schemaElement
.getRelatedSchemaElements()))
.findFirst();
if (metric.isPresent()) { if (metric.isPresent()) {
drillDownDimensions = drillDownDimensions = metric.get().getRelatedSchemaElements().stream()
metric.get().getRelatedSchemaElements().stream() .map(RelatedSchemaElement::getDimensionId).collect(Collectors.toList());
.map(RelatedSchemaElement::getDimensionId)
.collect(Collectors.toList());
} }
} }
final List<Long> drillDownDimensionsFinal = drillDownDimensions; final List<Long> drillDownDimensionsFinal = drillDownDimensions;
return dataSetSchema.getDimensions().stream() return dataSetSchema.getDimensions().stream()
.filter(dim -> filterDimension(drillDownDimensionsFinal, dim)) .filter(dim -> filterDimension(drillDownDimensionsFinal, dim))
.sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed())
.limit(recommend_dimension_size) .limit(recommend_dimension_size).collect(Collectors.toList());
.collect(Collectors.toList());
} }
private boolean filterDimension(List<Long> drillDownDimensions, SchemaElement dimension) { private boolean filterDimension(List<Long> drillDownDimensions, SchemaElement dimension) {

View File

@@ -69,19 +69,14 @@ public class MetricRatioProcessor implements ExecuteResultProcessor {
queryResult.setAggregateInfo(aggregateInfo); queryResult.setAggregateInfo(aggregateInfo);
} }
public AggregateInfo getAggregateInfo( public AggregateInfo getAggregateInfo(User user, SemanticParseInfo semanticParseInfo,
User user, SemanticParseInfo semanticParseInfo, QueryResult queryResult) { QueryResult queryResult) {
Set<String> resultMetricNames = new HashSet<>(); Set<String> resultMetricNames = new HashSet<>();
queryResult.getQueryColumns().stream() queryResult.getQueryColumns().stream().forEach(
.forEach( c -> resultMetricNames.addAll(SqlSelectHelper.getColumnFromExpr(c.getNameEn())));
c -> Optional<SchemaElement> ratioMetric = semanticParseInfo.getMetrics().stream()
resultMetricNames.addAll( .filter(m -> resultMetricNames.contains(m.getBizName())).findFirst();
SqlSelectHelper.getColumnFromExpr(c.getNameEn())));
Optional<SchemaElement> ratioMetric =
semanticParseInfo.getMetrics().stream()
.filter(m -> resultMetricNames.contains(m.getBizName()))
.findFirst();
AggregateInfo aggregateInfo = new AggregateInfo(); AggregateInfo aggregateInfo = new AggregateInfo();
if (!ratioMetric.isPresent()) { if (!ratioMetric.isPresent()) {
@@ -90,20 +85,15 @@ public class MetricRatioProcessor implements ExecuteResultProcessor {
try { try {
String dateField = QueryReqBuilder.getDateField(semanticParseInfo.getDateInfo()); String dateField = QueryReqBuilder.getDateField(semanticParseInfo.getDateInfo());
Optional<String> lastDayOp = Optional<String> lastDayOp = queryResult.getQueryResults().stream()
queryResult.getQueryResults().stream() .filter(r -> r.containsKey(dateField)).map(r -> r.get(dateField).toString())
.filter(r -> r.containsKey(dateField)) .sorted(Comparator.reverseOrder()).findFirst();
.map(r -> r.get(dateField).toString())
.sorted(Comparator.reverseOrder())
.findFirst();
if (!lastDayOp.isPresent()) { if (!lastDayOp.isPresent()) {
return new AggregateInfo(); return new AggregateInfo();
} }
Optional<Map<String, Object>> lastValue = Optional<Map<String, Object>> lastValue = queryResult.getQueryResults().stream()
queryResult.getQueryResults().stream() .filter(r -> r.get(dateField).toString().equals(lastDayOp.get())).findFirst();
.filter(r -> r.get(dateField).toString().equals(lastDayOp.get()))
.findFirst();
MetricInfo metricInfo = new MetricInfo(); MetricInfo metricInfo = new MetricInfo();
metricInfo.setStatistics(new HashMap<>()); metricInfo.setStatistics(new HashMap<>());
@@ -115,23 +105,11 @@ public class MetricRatioProcessor implements ExecuteResultProcessor {
metricInfo.setDate(lastValue.get().get(dateField).toString()); metricInfo.setDate(lastValue.get().get(dateField).toString());
CompletableFuture<MetricInfo> metricInfoRoll = CompletableFuture<MetricInfo> metricInfoRoll =
CompletableFuture.supplyAsync( CompletableFuture.supplyAsync(() -> queryRatio(user, semanticParseInfo,
() -> ratioMetric.get(), AggOperatorEnum.RATIO_ROLL, queryResult));
queryRatio(
user,
semanticParseInfo,
ratioMetric.get(),
AggOperatorEnum.RATIO_ROLL,
queryResult));
CompletableFuture<MetricInfo> metricInfoOver = CompletableFuture<MetricInfo> metricInfoOver =
CompletableFuture.supplyAsync( CompletableFuture.supplyAsync(() -> queryRatio(user, semanticParseInfo,
() -> ratioMetric.get(), AggOperatorEnum.RATIO_OVER, queryResult));
queryRatio(
user,
semanticParseInfo,
ratioMetric.get(),
AggOperatorEnum.RATIO_OVER,
queryResult));
CompletableFuture.allOf(metricInfoRoll, metricInfoOver); CompletableFuture.allOf(metricInfoRoll, metricInfoOver);
metricInfo.setName(metricInfoRoll.get().getName()); metricInfo.setName(metricInfoRoll.get().getName());
metricInfo.setValue(metricInfoRoll.get().getValue()); metricInfo.setValue(metricInfoRoll.get().getValue());
@@ -145,19 +123,15 @@ public class MetricRatioProcessor implements ExecuteResultProcessor {
} }
@SneakyThrows @SneakyThrows
private MetricInfo queryRatio( private MetricInfo queryRatio(User user, SemanticParseInfo semanticParseInfo,
User user, SchemaElement metric, AggOperatorEnum aggOperatorEnum, QueryResult queryResult) {
SemanticParseInfo semanticParseInfo,
SchemaElement metric,
AggOperatorEnum aggOperatorEnum,
QueryResult queryResult) {
QueryStructReq queryStructReq = QueryStructReq queryStructReq =
QueryReqBuilder.buildStructRatioReq(semanticParseInfo, metric, aggOperatorEnum); QueryReqBuilder.buildStructRatioReq(semanticParseInfo, metric, aggOperatorEnum);
String dateField = QueryReqBuilder.getDateField(semanticParseInfo.getDateInfo()); String dateField = QueryReqBuilder.getDateField(semanticParseInfo.getDateInfo());
queryStructReq.setGroups(new ArrayList<>(Arrays.asList(dateField))); queryStructReq.setGroups(new ArrayList<>(Arrays.asList(dateField)));
queryStructReq.setDateInfo( queryStructReq
getRatioDateConf(aggOperatorEnum, semanticParseInfo, queryResult)); .setDateInfo(getRatioDateConf(aggOperatorEnum, semanticParseInfo, queryResult));
queryStructReq.setConvertToSql(false); queryStructReq.setConvertToSql(false);
SemanticLayerService queryService = ContextUtils.getBean(SemanticLayerService.class); SemanticLayerService queryService = ContextUtils.getBean(SemanticLayerService.class);
SemanticQueryResp queryResp = queryService.queryByReq(queryStructReq, user); SemanticQueryResp queryResp = queryService.queryByReq(queryStructReq, user);
@@ -168,26 +142,22 @@ public class MetricRatioProcessor implements ExecuteResultProcessor {
} }
Map<String, Object> result = queryResp.getResultList().get(0); Map<String, Object> result = queryResp.getResultList().get(0);
Optional<QueryColumn> valueColumn = Optional<QueryColumn> valueColumn = queryResp.getColumns().stream()
queryResp.getColumns().stream() .filter(c -> c.getNameEn().equals(metric.getBizName())).findFirst();
.filter(c -> c.getNameEn().equals(metric.getBizName()))
.findFirst();
if (!valueColumn.isPresent()) { if (!valueColumn.isPresent()) {
return metricInfo; return metricInfo;
} }
String valueField = String valueField = String.format("%s_%s", valueColumn.get().getNameEn(),
String.format( aggOperatorEnum.getOperator());
"%s_%s", valueColumn.get().getNameEn(), aggOperatorEnum.getOperator());
if (result.containsKey(valueColumn.get().getNameEn())) { if (result.containsKey(valueColumn.get().getNameEn())) {
DecimalFormat df = new DecimalFormat("#.####"); DecimalFormat df = new DecimalFormat("#.####");
metricInfo.setValue(df.format(result.get(valueColumn.get().getNameEn()))); metricInfo.setValue(df.format(result.get(valueColumn.get().getNameEn())));
} }
String ratio = ""; String ratio = "";
if (Objects.nonNull(result.get(valueField))) { if (Objects.nonNull(result.get(valueField))) {
ratio = ratio = String.format("%.2f", (Double.valueOf(result.get(valueField).toString()) * 100))
String.format("%.2f", (Double.valueOf(result.get(valueField).toString()) * 100)) + "%";
+ "%";
} }
String statisticsRollName = RatioOverType.DAY_ON_DAY.getShowName(); String statisticsRollName = RatioOverType.DAY_ON_DAY.getShowName();
String statisticsOverName = RatioOverType.WEEK_ON_DAY.getShowName(); String statisticsOverName = RatioOverType.WEEK_ON_DAY.getShowName();
@@ -199,28 +169,20 @@ public class MetricRatioProcessor implements ExecuteResultProcessor {
statisticsRollName = RatioOverType.WEEK_ON_WEEK.getShowName(); statisticsRollName = RatioOverType.WEEK_ON_WEEK.getShowName();
statisticsOverName = RatioOverType.MONTH_ON_WEEK.getShowName(); statisticsOverName = RatioOverType.MONTH_ON_WEEK.getShowName();
} }
metricInfo metricInfo.getStatistics()
.getStatistics() .put(aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL) ? statisticsRollName
.put( : statisticsOverName, ratio);
aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL)
? statisticsRollName
: statisticsOverName,
ratio);
metricInfo.setName(metric.getName()); metricInfo.setName(metric.getName());
return metricInfo; return metricInfo;
} }
private DateConf getRatioDateConf( private DateConf getRatioDateConf(AggOperatorEnum aggOperatorEnum,
AggOperatorEnum aggOperatorEnum, SemanticParseInfo semanticParseInfo, QueryResult queryResult) {
SemanticParseInfo semanticParseInfo,
QueryResult queryResult) {
String dateField = QueryReqBuilder.getDateField(semanticParseInfo.getDateInfo()); String dateField = QueryReqBuilder.getDateField(semanticParseInfo.getDateInfo());
Optional<String> lastDayOp = Optional<String> lastDayOp =
queryResult.getQueryResults().stream() queryResult.getQueryResults().stream().map(r -> r.get(dateField).toString())
.map(r -> r.get(dateField).toString()) .sorted(Comparator.reverseOrder()).findFirst();
.sorted(Comparator.reverseOrder())
.findFirst();
if (!lastDayOp.isPresent()) { if (!lastDayOp.isPresent()) {
return semanticParseInfo.getDateInfo(); return semanticParseInfo.getDateInfo();
@@ -236,31 +198,25 @@ public class MetricRatioProcessor implements ExecuteResultProcessor {
DateTimeFormatter formatter = DateTimeFormatter formatter =
DateUtils.getDateFormatter(lastDay, new String[] {DAY_FORMAT, DAY_FORMAT_INT}); DateUtils.getDateFormatter(lastDay, new String[] {DAY_FORMAT, DAY_FORMAT_INT});
LocalDate end = LocalDate.parse(lastDay, formatter); LocalDate end = LocalDate.parse(lastDay, formatter);
start = start = aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL)
aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL) ? end.minusDays(1).format(formatter)
? end.minusDays(1).format(formatter) : end.minusWeeks(1).format(formatter);
: end.minusWeeks(1).format(formatter);
} }
if (DatePeriodEnum.WEEK.equals(semanticParseInfo.getDateInfo().getPeriod())) { if (DatePeriodEnum.WEEK.equals(semanticParseInfo.getDateInfo().getPeriod())) {
DateTimeFormatter formatter = DateTimeFormatter formatter = DateUtils.getTimeFormatter(lastDay,
DateUtils.getTimeFormatter( new String[] {TIMES_FORMAT, DAY_FORMAT, TIME_FORMAT, DAY_FORMAT_INT});
lastDay,
new String[] {TIMES_FORMAT, DAY_FORMAT, TIME_FORMAT, DAY_FORMAT_INT});
LocalDateTime end = LocalDateTime.parse(lastDay, formatter); LocalDateTime end = LocalDateTime.parse(lastDay, formatter);
start = start = aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL)
aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL) ? end.minusWeeks(1).format(formatter)
? end.minusWeeks(1).format(formatter) : end.minusMonths(1).with(DayOfWeek.MONDAY).format(formatter);
: end.minusMonths(1).with(DayOfWeek.MONDAY).format(formatter);
} }
if (DatePeriodEnum.MONTH.equals(semanticParseInfo.getDateInfo().getPeriod())) { if (DatePeriodEnum.MONTH.equals(semanticParseInfo.getDateInfo().getPeriod())) {
DateTimeFormatter formatter = DateTimeFormatter formatter = DateUtils.getDateFormatter(lastDay,
DateUtils.getDateFormatter( new String[] {MONTH_FORMAT, MONTH_FORMAT_INT});
lastDay, new String[] {MONTH_FORMAT, MONTH_FORMAT_INT});
YearMonth end = YearMonth.parse(lastDay, formatter); YearMonth end = YearMonth.parse(lastDay, formatter);
start = start = aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL)
aggOperatorEnum.equals(AggOperatorEnum.RATIO_ROLL) ? end.minusMonths(1).format(formatter)
? end.minusMonths(1).format(formatter) : end.minusYears(1).format(formatter);
: end.minusYears(1).format(formatter);
} }
dayList.add(start); dayList.add(start);
dateConf.setDateList(dayList); dateConf.setDateList(dayList);

View File

@@ -45,33 +45,24 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
List<String> metricNames = List<String> metricNames =
Collections.singletonList(parseInfo.getMetrics().iterator().next().getName()); Collections.singletonList(parseInfo.getMetrics().iterator().next().getName());
Map<String, Object> filterCondition = new HashMap<>(); Map<String, Object> filterCondition = new HashMap<>();
filterCondition.put( filterCondition.put("modelId",
"modelId", parseInfo.getMetrics().iterator().next().getDataSetId().toString()); parseInfo.getMetrics().iterator().next().getDataSetId().toString());
filterCondition.put("type", SchemaElementType.METRIC.name()); filterCondition.put("type", SchemaElementType.METRIC.name());
RetrieveQuery retrieveQuery = RetrieveQuery retrieveQuery = RetrieveQuery.builder().queryTextsList(metricNames)
RetrieveQuery.builder() .filterCondition(filterCondition).queryEmbeddings(null).build();
.queryTextsList(metricNames)
.filterCondition(filterCondition)
.queryEmbeddings(null)
.build();
MetaEmbeddingService metaEmbeddingService = MetaEmbeddingService metaEmbeddingService =
ContextUtils.getBean(MetaEmbeddingService.class); ContextUtils.getBean(MetaEmbeddingService.class);
List<RetrieveQueryResult> retrieveQueryResults = List<RetrieveQueryResult> retrieveQueryResults = metaEmbeddingService.retrieveQuery(
metaEmbeddingService.retrieveQuery( retrieveQuery, METRIC_RECOMMEND_SIZE + 1, new HashMap<>(), new HashSet<>());
retrieveQuery, METRIC_RECOMMEND_SIZE + 1, new HashMap<>(), new HashSet<>());
if (CollectionUtils.isEmpty(retrieveQueryResults)) { if (CollectionUtils.isEmpty(retrieveQueryResults)) {
return; return;
} }
List<Retrieval> retrievals = List<Retrieval> retrievals = retrieveQueryResults.stream()
retrieveQueryResults.stream() .flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream())
.flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream()) .sorted(Comparator.comparingDouble(Retrieval::getSimilarity)).distinct()
.sorted(Comparator.comparingDouble(Retrieval::getSimilarity)) .collect(Collectors.toList());
.distinct() Set<Long> metricIds = parseInfo.getMetrics().stream().map(SchemaElement::getId)
.collect(Collectors.toList()); .collect(Collectors.toSet());
Set<Long> metricIds =
parseInfo.getMetrics().stream()
.map(SchemaElement::getId)
.collect(Collectors.toSet());
int metricOrder = 0; int metricOrder = 0;
for (SchemaElement metric : parseInfo.getMetrics()) { for (SchemaElement metric : parseInfo.getMetrics()) {
metric.setOrder(metricOrder++); metric.setOrder(metricOrder++);
@@ -79,23 +70,15 @@ public class MetricRecommendProcessor implements ExecuteResultProcessor {
for (Retrieval retrieval : retrievals) { for (Retrieval retrieval : retrievals) {
if (!metricIds.contains(Retrieval.getLongId(retrieval.getId()))) { if (!metricIds.contains(Retrieval.getLongId(retrieval.getId()))) {
if (Objects.nonNull(retrieval.getMetadata().get("id"))) { if (Objects.nonNull(retrieval.getMetadata().get("id"))) {
String idStr = String idStr = retrieval.getMetadata().get("id").toString()
retrieval .replaceAll(DictWordType.NATURE_SPILT, "");
.getMetadata()
.get("id")
.toString()
.replaceAll(DictWordType.NATURE_SPILT, "");
retrieval.getMetadata().put("id", idStr); retrieval.getMetadata().put("id", idStr);
} }
String metaStr = JSONObject.toJSONString(retrieval.getMetadata()); String metaStr = JSONObject.toJSONString(retrieval.getMetadata());
SchemaElement schemaElement = JSONObject.parseObject(metaStr, SchemaElement.class); SchemaElement schemaElement = JSONObject.parseObject(metaStr, SchemaElement.class);
if (retrieval.getMetadata().containsKey("dataSetId")) { if (retrieval.getMetadata().containsKey("dataSetId")) {
String dataSetId = String dataSetId = retrieval.getMetadata().get("dataSetId").toString()
retrieval .replace(Constants.UNDERLINE, "");
.getMetadata()
.get("dataSetId")
.toString()
.replace(Constants.UNDERLINE, "");
schemaElement.setDataSetId(Long.parseLong(dataSetId)); schemaElement.setDataSetId(Long.parseLong(dataSetId));
} }
schemaElement.setOrder(++metricOrder); schemaElement.setOrder(++metricOrder);

View File

@@ -43,13 +43,8 @@ public class QueryRecommendProcessor implements ParseResultProcessor {
String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId); String memoryCollectionName = embeddingConfig.getMemoryCollectionName(agentId);
List<Text2SQLExemplar> exemplars = List<Text2SQLExemplar> exemplars =
exemplarService.recallExemplars(memoryCollectionName, queryText, 5); exemplarService.recallExemplars(memoryCollectionName, queryText, 5);
return exemplars.stream() return exemplars.stream().map(sqlExemplar -> SimilarQueryRecallResp.builder()
.map( .queryText(sqlExemplar.getQuestion()).build()).collect(Collectors.toList());
sqlExemplar ->
SimilarQueryRecallResp.builder()
.queryText(sqlExemplar.getQuestion())
.build())
.collect(Collectors.toList());
} }
private ChatQueryDO getChatQuery(Long queryId) { private ChatQueryDO getChatQuery(Long queryId) {

View File

@@ -11,11 +11,7 @@ public class TimeCostProcessor implements ParseResultProcessor {
@Override @Override
public void process(ParseContext parseContext, ParseResp parseResp) { public void process(ParseContext parseContext, ParseResp parseResp) {
long parseStartTime = parseResp.getParseTimeCost().getParseStartTime(); long parseStartTime = parseResp.getParseTimeCost().getParseStartTime();
parseResp parseResp.getParseTimeCost().setParseTime(System.currentTimeMillis() - parseStartTime
.getParseTimeCost() - parseResp.getParseTimeCost().getSqlTime());
.setParseTime(
System.currentTimeMillis()
- parseStartTime
- parseResp.getParseTimeCost().getSqlTime());
} }
} }

View File

@@ -26,21 +26,18 @@ import java.util.Map;
@RequestMapping({"/api/chat/agent", "/openapi/chat/agent"}) @RequestMapping({"/api/chat/agent", "/openapi/chat/agent"})
public class AgentController { public class AgentController {
@Autowired private AgentService agentService; @Autowired
private AgentService agentService;
@PostMapping @PostMapping
public Agent createAgent( public Agent createAgent(@RequestBody Agent agent, HttpServletRequest httpServletRequest,
@RequestBody Agent agent,
HttpServletRequest httpServletRequest,
HttpServletResponse httpServletResponse) { HttpServletResponse httpServletResponse) {
User user = UserHolder.findUser(httpServletRequest, httpServletResponse); User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
return agentService.createAgent(agent, user); return agentService.createAgent(agent, user);
} }
@PutMapping @PutMapping
public Agent updateAgent( public Agent updateAgent(@RequestBody Agent agent, HttpServletRequest httpServletRequest,
@RequestBody Agent agent,
HttpServletRequest httpServletRequest,
HttpServletResponse httpServletResponse) { HttpServletResponse httpServletResponse) {
User user = UserHolder.findUser(httpServletRequest, httpServletResponse); User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
return agentService.updateAgent(agent, user); return agentService.updateAgent(agent, user);

View File

@@ -29,33 +29,29 @@ import java.util.List;
@RequestMapping({"/api/chat/conf", "/openapi/chat/conf"}) @RequestMapping({"/api/chat/conf", "/openapi/chat/conf"})
public class ChatConfigController { public class ChatConfigController {
@Autowired private ConfigService configService; @Autowired
private ConfigService configService;
@Autowired private SemanticLayerService semanticLayerService; @Autowired
private SemanticLayerService semanticLayerService;
@PostMapping @PostMapping
public Long addChatConfig( public Long addChatConfig(@RequestBody ChatConfigBaseReq extendBaseCmd,
@RequestBody ChatConfigBaseReq extendBaseCmd, HttpServletRequest request, HttpServletResponse response) {
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response); User user = UserHolder.findUser(request, response);
return configService.addConfig(extendBaseCmd, user); return configService.addConfig(extendBaseCmd, user);
} }
@PutMapping @PutMapping
public Long editModelExtend( public Long editModelExtend(@RequestBody ChatConfigEditReqReq extendEditCmd,
@RequestBody ChatConfigEditReqReq extendEditCmd, HttpServletRequest request, HttpServletResponse response) {
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response); User user = UserHolder.findUser(request, response);
return configService.editConfig(extendEditCmd, user); return configService.editConfig(extendEditCmd, user);
} }
@PostMapping("/search") @PostMapping("/search")
public List<ChatConfigResp> search( public List<ChatConfigResp> search(@RequestBody ChatConfigFilter filter,
@RequestBody ChatConfigFilter filter, HttpServletRequest request, HttpServletResponse response) {
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response); User user = UserHolder.findUser(request, response);
return configService.search(filter, user); return configService.search(filter, user);
} }

View File

@@ -25,14 +25,13 @@ import java.util.List;
@RequestMapping({"/api/chat/manage", "/openapi/chat/manage"}) @RequestMapping({"/api/chat/manage", "/openapi/chat/manage"})
public class ChatController { public class ChatController {
@Autowired private ChatManageService chatService; @Autowired
private ChatManageService chatService;
@PostMapping("/save") @PostMapping("/save")
public Boolean save( public Boolean save(@RequestParam(value = "chatName") String chatName,
@RequestParam(value = "chatName") String chatName,
@RequestParam(value = "agentId", required = false) Integer agentId, @RequestParam(value = "agentId", required = false) Integer agentId,
HttpServletRequest request, HttpServletRequest request, HttpServletResponse response) {
HttpServletResponse response) {
chatService.addChat(UserHolder.findUser(request, response), chatName, agentId); chatService.addChat(UserHolder.findUser(request, response), chatName, agentId);
return true; return true;
} }
@@ -40,50 +39,42 @@ public class ChatController {
@GetMapping("/getAll") @GetMapping("/getAll")
public List<ChatDO> getAllConversions( public List<ChatDO> getAllConversions(
@RequestParam(value = "agentId", required = false) Integer agentId, @RequestParam(value = "agentId", required = false) Integer agentId,
HttpServletRequest request, HttpServletRequest request, HttpServletResponse response) {
HttpServletResponse response) {
String userName = UserHolder.findUser(request, response).getName(); String userName = UserHolder.findUser(request, response).getName();
return chatService.getAll(userName, agentId); return chatService.getAll(userName, agentId);
} }
@PostMapping("/delete") @PostMapping("/delete")
public Boolean deleteConversion( public Boolean deleteConversion(@RequestParam(value = "chatId") long chatId,
@RequestParam(value = "chatId") long chatId, HttpServletRequest request, HttpServletResponse response) {
HttpServletRequest request,
HttpServletResponse response) {
String userName = UserHolder.findUser(request, response).getName(); String userName = UserHolder.findUser(request, response).getName();
return chatService.deleteChat(chatId, userName); return chatService.deleteChat(chatId, userName);
} }
@PostMapping("/updateChatName") @PostMapping("/updateChatName")
public Boolean updateConversionName( public Boolean updateConversionName(@RequestParam(value = "chatId") Long chatId,
@RequestParam(value = "chatId") Long chatId, @RequestParam(value = "chatName") String chatName, HttpServletRequest request,
@RequestParam(value = "chatName") String chatName,
HttpServletRequest request,
HttpServletResponse response) { HttpServletResponse response) {
String userName = UserHolder.findUser(request, response).getName(); String userName = UserHolder.findUser(request, response).getName();
return chatService.updateChatName(chatId, chatName, userName); return chatService.updateChatName(chatId, chatName, userName);
} }
@PostMapping("/updateQAFeedback") @PostMapping("/updateQAFeedback")
public Boolean updateQAFeedback( public Boolean updateQAFeedback(@RequestParam(value = "id") Integer id,
@RequestParam(value = "id") Integer id,
@RequestParam(value = "score") Integer score, @RequestParam(value = "score") Integer score,
@RequestParam(value = "feedback", required = false) String feedback) { @RequestParam(value = "feedback", required = false) String feedback) {
return chatService.updateFeedback(id, score, feedback); return chatService.updateFeedback(id, score, feedback);
} }
@PostMapping("/updateChatIsTop") @PostMapping("/updateChatIsTop")
public Boolean updateConversionIsTop( public Boolean updateConversionIsTop(@RequestParam(value = "chatId") Long chatId,
@RequestParam(value = "chatId") Long chatId, @RequestParam(value = "isTop") int isTop) { @RequestParam(value = "isTop") int isTop) {
return chatService.updateChatIsTop(chatId, isTop); return chatService.updateChatIsTop(chatId, isTop);
} }
@PostMapping("/pageQueryInfo") @PostMapping("/pageQueryInfo")
public PageInfo<QueryResp> pageQueryInfo( public PageInfo<QueryResp> pageQueryInfo(@RequestBody PageQueryInfoReq pageQueryInfoCommand,
@RequestBody PageQueryInfoReq pageQueryInfoCommand, @RequestParam(value = "chatId") long chatId, HttpServletRequest request,
@RequestParam(value = "chatId") long chatId,
HttpServletRequest request,
HttpServletResponse response) { HttpServletResponse response) {
pageQueryInfoCommand.setUserName(UserHolder.findUser(request, response).getName()); pageQueryInfoCommand.setUserName(UserHolder.findUser(request, response).getName());
return chatService.queryInfo(pageQueryInfoCommand, chatId); return chatService.queryInfo(pageQueryInfoCommand, chatId);
@@ -95,8 +86,7 @@ public class ChatController {
} }
@PostMapping("/queryShowCase") @PostMapping("/queryShowCase")
public ShowCaseResp queryShowCase( public ShowCaseResp queryShowCase(@RequestBody PageQueryInfoReq pageQueryInfoCommand,
@RequestBody PageQueryInfoReq pageQueryInfoCommand,
@RequestParam(value = "agentId") int agentId) { @RequestParam(value = "agentId") int agentId) {
return chatService.queryShowCase(pageQueryInfoCommand, agentId); return chatService.queryShowCase(pageQueryInfoCommand, agentId);
} }

View File

@@ -27,43 +27,33 @@ import org.springframework.web.bind.annotation.RestController;
@RequestMapping({"/api/chat/query", "/openapi/chat/query"}) @RequestMapping({"/api/chat/query", "/openapi/chat/query"})
public class ChatQueryController { public class ChatQueryController {
@Autowired private ChatQueryService chatQueryService; @Autowired
private ChatQueryService chatQueryService;
@PostMapping("search") @PostMapping("search")
public Object search( public Object search(@RequestBody ChatParseReq chatParseReq, HttpServletRequest request,
@RequestBody ChatParseReq chatParseReq,
HttpServletRequest request,
HttpServletResponse response) { HttpServletResponse response) {
chatParseReq.setUser(UserHolder.findUser(request, response)); chatParseReq.setUser(UserHolder.findUser(request, response));
return chatQueryService.search(chatParseReq); return chatQueryService.search(chatParseReq);
} }
@PostMapping("parse") @PostMapping("parse")
public Object parse( public Object parse(@RequestBody ChatParseReq chatParseReq, HttpServletRequest request,
@RequestBody ChatParseReq chatParseReq, HttpServletResponse response) throws Exception {
HttpServletRequest request,
HttpServletResponse response)
throws Exception {
chatParseReq.setUser(UserHolder.findUser(request, response)); chatParseReq.setUser(UserHolder.findUser(request, response));
return chatQueryService.performParsing(chatParseReq); return chatQueryService.performParsing(chatParseReq);
} }
@PostMapping("execute") @PostMapping("execute")
public Object execute( public Object execute(@RequestBody ChatExecuteReq chatExecuteReq, HttpServletRequest request,
@RequestBody ChatExecuteReq chatExecuteReq, HttpServletResponse response) throws Exception {
HttpServletRequest request,
HttpServletResponse response)
throws Exception {
chatExecuteReq.setUser(UserHolder.findUser(request, response)); chatExecuteReq.setUser(UserHolder.findUser(request, response));
return chatQueryService.performExecution(chatExecuteReq); return chatQueryService.performExecution(chatExecuteReq);
} }
@PostMapping("/") @PostMapping("/")
public Object query( public Object query(@RequestBody ChatParseReq chatParseReq, HttpServletRequest request,
@RequestBody ChatParseReq chatParseReq, HttpServletResponse response) throws Exception {
HttpServletRequest request,
HttpServletResponse response)
throws Exception {
User user = UserHolder.findUser(request, response); User user = UserHolder.findUser(request, response);
chatParseReq.setUser(user); chatParseReq.setUser(user);
ParseResp parseResp = chatQueryService.performParsing(chatParseReq); ParseResp parseResp = chatQueryService.performParsing(chatParseReq);
@@ -80,22 +70,16 @@ public class ChatQueryController {
} }
@PostMapping("queryData") @PostMapping("queryData")
public Object queryData( public Object queryData(@RequestBody ChatQueryDataReq chatQueryDataReq,
@RequestBody ChatQueryDataReq chatQueryDataReq, HttpServletRequest request, HttpServletResponse response) throws Exception {
HttpServletRequest request,
HttpServletResponse response)
throws Exception {
chatQueryDataReq.setUser(UserHolder.findUser(request, response)); chatQueryDataReq.setUser(UserHolder.findUser(request, response));
return chatQueryService.queryData(chatQueryDataReq, UserHolder.findUser(request, response)); return chatQueryService.queryData(chatQueryDataReq, UserHolder.findUser(request, response));
} }
@PostMapping("queryDimensionValue") @PostMapping("queryDimensionValue")
public Object queryDimensionValue( public Object queryDimensionValue(@RequestBody @Valid DimensionValueReq dimensionValueReq,
@RequestBody @Valid DimensionValueReq dimensionValueReq, HttpServletRequest request, HttpServletResponse response) throws Exception {
HttpServletRequest request, return chatQueryService.queryDimensionValue(dimensionValueReq,
HttpServletResponse response) UserHolder.findUser(request, response));
throws Exception {
return chatQueryService.queryDimensionValue(
dimensionValueReq, UserHolder.findUser(request, response));
} }
} }

View File

@@ -21,13 +21,12 @@ import org.springframework.web.bind.annotation.RestController;
@RequestMapping({"/api/chat/memory"}) @RequestMapping({"/api/chat/memory"})
public class MemoryController { public class MemoryController {
@Autowired private MemoryService memoryService; @Autowired
private MemoryService memoryService;
@PostMapping("/updateMemory") @PostMapping("/updateMemory")
public Boolean updateMemory( public Boolean updateMemory(@RequestBody ChatMemoryUpdateReq chatMemoryUpdateReq,
@RequestBody ChatMemoryUpdateReq chatMemoryUpdateReq, HttpServletRequest request, HttpServletResponse response) {
HttpServletRequest request,
HttpServletResponse response) {
User user = UserHolder.findUser(request, response); User user = UserHolder.findUser(request, response);
memoryService.updateMemory(chatMemoryUpdateReq, user); memoryService.updateMemory(chatMemoryUpdateReq, user);
return true; return true;

View File

@@ -25,23 +25,20 @@ import java.util.List;
@RequestMapping("/api/chat/plugin") @RequestMapping("/api/chat/plugin")
public class PluginController { public class PluginController {
@Autowired protected PluginService pluginService; @Autowired
protected PluginService pluginService;
@PostMapping @PostMapping
public boolean createPlugin( public boolean createPlugin(@RequestBody ChatPlugin plugin,
@RequestBody ChatPlugin plugin, HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
HttpServletRequest httpServletRequest,
HttpServletResponse httpServletResponse) {
User user = UserHolder.findUser(httpServletRequest, httpServletResponse); User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
pluginService.createPlugin(plugin, user); pluginService.createPlugin(plugin, user);
return true; return true;
} }
@PutMapping @PutMapping
public boolean updatePlugin( public boolean updatePlugin(@RequestBody ChatPlugin plugin,
@RequestBody ChatPlugin plugin, HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
HttpServletRequest httpServletRequest,
HttpServletResponse httpServletResponse) {
User user = UserHolder.findUser(httpServletRequest, httpServletResponse); User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
pluginService.updatePlugin(plugin, user); pluginService.updatePlugin(plugin, user);
return true; return true;
@@ -59,18 +56,16 @@ public class PluginController {
} }
@PostMapping("/query") @PostMapping("/query")
List<ChatPlugin> query( List<ChatPlugin> query(@RequestBody PluginQueryReq pluginQueryReq,
@RequestBody PluginQueryReq pluginQueryReq, HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
HttpServletRequest httpServletRequest,
HttpServletResponse httpServletResponse) {
User user = UserHolder.findUser(httpServletRequest, httpServletResponse); User user = UserHolder.findUser(httpServletRequest, httpServletResponse);
return pluginService.queryWithAuthCheck(pluginQueryReq, user); return pluginService.queryWithAuthCheck(pluginQueryReq, user);
} }
@AuthenticationIgnore @AuthenticationIgnore
@PostMapping("/pluginDemo") @PostMapping("/pluginDemo")
public String pluginDemo( public String pluginDemo(@RequestParam("queryText") String queryText,
@RequestParam("queryText") String queryText, @RequestBody Object object) { @RequestBody Object object) {
return String.format("已收到您的问题:%s, 但这只是一个demo~", queryText); return String.format("已收到您的问题:%s, 但这只是一个demo~", queryText);
} }
} }

View File

@@ -33,9 +33,11 @@ import java.util.stream.Collectors;
@Service @Service
public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implements AgentService { public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implements AgentService {
@Autowired private MemoryService memoryService; @Autowired
private MemoryService memoryService;
@Autowired private ChatQueryService chatQueryService; @Autowired
private ChatQueryService chatQueryService;
private ExecutorService executorService = Executors.newFixedThreadPool(1); private ExecutorService executorService = Executors.newFixedThreadPool(1);
@@ -98,8 +100,7 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
} }
private synchronized void doExecuteAgentExamples(Agent agent) { private synchronized void doExecuteAgentExamples(Agent agent) {
if (!agent.containsLLMTool() if (!agent.containsLLMTool() || !LLMConnHelper.testConnection(agent.getModelConfig())
|| !LLMConnHelper.testConnection(agent.getModelConfig())
|| CollectionUtils.isEmpty(agent.getExamples())) { || CollectionUtils.isEmpty(agent.getExamples())) {
return; return;
} }
@@ -107,10 +108,8 @@ public class AgentServiceImpl extends ServiceImpl<AgentDOMapper, AgentDO> implem
List<String> examples = agent.getExamples(); List<String> examples = agent.getExamples();
ChatMemoryFilter chatMemoryFilter = ChatMemoryFilter chatMemoryFilter =
ChatMemoryFilter.builder().agentId(agent.getId()).questions(examples).build(); ChatMemoryFilter.builder().agentId(agent.getId()).questions(examples).build();
List<String> memoriesExisted = List<String> memoriesExisted = memoryService.getMemories(chatMemoryFilter).stream()
memoryService.getMemories(chatMemoryFilter).stream() .map(ChatMemoryDO::getQuestion).collect(Collectors.toList());
.map(ChatMemoryDO::getQuestion)
.collect(Collectors.toList());
for (String example : examples) { for (String example : examples) {
if (memoriesExisted.contains(example)) { if (memoriesExisted.contains(example)) {
continue; continue;

View File

@@ -37,8 +37,10 @@ import java.util.stream.Collectors;
@Service @Service
public class ChatManageServiceImpl implements ChatManageService { public class ChatManageServiceImpl implements ChatManageService {
@Autowired private ChatRepository chatRepository; @Autowired
@Autowired private ChatQueryRepository chatQueryRepository; private ChatRepository chatRepository;
@Autowired
private ChatQueryRepository chatQueryRepository;
@Override @Override
public Long addChat(User user, String chatName, Integer agentId) { public Long addChat(User user, String chatName, Integer agentId) {
@@ -121,30 +123,23 @@ public class ChatManageServiceImpl implements ChatManageService {
if (CollectionUtils.isEmpty(queryResps)) { if (CollectionUtils.isEmpty(queryResps)) {
return showCaseResp; return showCaseResp;
} }
queryResps.removeIf( queryResps.removeIf(queryResp -> {
queryResp -> { if (queryResp.getQueryResult() == null) {
if (queryResp.getQueryResult() == null) { return true;
return true; }
} if (queryResp.getQueryResult().getResponse() != null) {
if (queryResp.getQueryResult().getResponse() != null) { return false;
return false; }
} if (CollectionUtils.isEmpty(queryResp.getQueryResult().getQueryResults())) {
if (CollectionUtils.isEmpty(queryResp.getQueryResult().getQueryResults())) { return true;
return true; }
} Map<String, Object> data = queryResp.getQueryResult().getQueryResults().get(0);
Map<String, Object> data = queryResp.getQueryResult().getQueryResults().get(0); return CollectionUtils.isEmpty(data);
return CollectionUtils.isEmpty(data); });
}); queryResps = new ArrayList<>(queryResps.stream()
queryResps = .collect(Collectors.toMap(QueryResp::getQueryText, Function.identity(),
new ArrayList<>( (existing, replacement) -> existing, LinkedHashMap::new))
queryResps.stream() .values());
.collect(
Collectors.toMap(
QueryResp::getQueryText,
Function.identity(),
(existing, replacement) -> existing,
LinkedHashMap::new))
.values());
fillParseInfo(queryResps); fillParseInfo(queryResps);
Map<Long, List<QueryResp>> showCaseMap = Map<Long, List<QueryResp>> showCaseMap =
queryResps.stream().collect(Collectors.groupingBy(QueryResp::getChatId)); queryResps.stream().collect(Collectors.groupingBy(QueryResp::getChatId));
@@ -166,17 +161,11 @@ public class ChatManageServiceImpl implements ChatManageService {
if (CollectionUtils.isEmpty(chatParseDOList)) { if (CollectionUtils.isEmpty(chatParseDOList)) {
continue; continue;
} }
List<SemanticParseInfo> parseInfos = List<SemanticParseInfo> parseInfos = chatParseDOList.stream()
chatParseDOList.stream() .map(chatParseDO -> JsonUtil.toObject(chatParseDO.getParseInfo(),
.map( SemanticParseInfo.class))
chatParseDO -> .sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed())
JsonUtil.toObject( .collect(Collectors.toList());
chatParseDO.getParseInfo(),
SemanticParseInfo.class))
.sorted(
Comparator.comparingDouble(SemanticParseInfo::getScore)
.reversed())
.collect(Collectors.toList());
queryResp.setParseInfos(parseInfos); queryResp.setParseInfos(parseInfos);
} }
} }
@@ -188,10 +177,8 @@ public class ChatManageServiceImpl implements ChatManageService {
chatQueryDO.setQueryResult(JsonUtil.toString(queryResult)); chatQueryDO.setQueryResult(JsonUtil.toString(queryResult));
chatQueryDO.setQueryState(1); chatQueryDO.setQueryState(1);
updateQuery(chatQueryDO); updateQuery(chatQueryDO);
chatRepository.updateLastQuestion( chatRepository.updateLastQuestion(chatExecuteReq.getChatId().longValue(),
chatExecuteReq.getChatId().longValue(), chatExecuteReq.getQueryText(), getCurrentTime());
chatExecuteReq.getQueryText(),
getCurrentTime());
return chatQueryDO; return chatQueryDO;
} }

View File

@@ -78,10 +78,14 @@ import java.util.stream.Collectors;
@Service @Service
public class ChatQueryServiceImpl implements ChatQueryService { public class ChatQueryServiceImpl implements ChatQueryService {
@Autowired private ChatManageService chatManageService; @Autowired
@Autowired private ChatLayerService chatLayerService; private ChatManageService chatManageService;
@Autowired private SemanticLayerService semanticLayerService; @Autowired
@Autowired private AgentService agentService; private ChatLayerService chatLayerService;
@Autowired
private SemanticLayerService semanticLayerService;
@Autowired
private AgentService agentService;
private List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers(); private List<ChatQueryParser> chatQueryParsers = ComponentFactory.getChatParsers();
private List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors(); private List<ChatQueryExecutor> chatQueryExecutors = ComponentFactory.getChatExecutors();
@@ -149,11 +153,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
chatParseReq.setUser(User.getFakeUser()); chatParseReq.setUser(User.getFakeUser());
ParseResp parseResp = performParsing(chatParseReq); ParseResp parseResp = performParsing(chatParseReq);
if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) { if (CollectionUtils.isEmpty(parseResp.getSelectedParses())) {
log.debug( log.debug("chatId:{}, agentId:{}, queryText:{}, parseResp.getSelectedParses() is empty",
"chatId:{}, agentId:{}, queryText:{}, parseResp.getSelectedParses() is empty", chatId, agentId, queryText);
chatId,
agentId,
queryText);
return null; return null;
} }
ChatExecuteReq executeReq = new ChatExecuteReq(); ChatExecuteReq executeReq = new ChatExecuteReq();
@@ -184,9 +185,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
private ExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) { private ExecuteContext buildExecuteContext(ChatExecuteReq chatExecuteReq) {
ExecuteContext executeContext = new ExecuteContext(); ExecuteContext executeContext = new ExecuteContext();
BeanMapper.mapper(chatExecuteReq, executeContext); BeanMapper.mapper(chatExecuteReq, executeContext);
SemanticParseInfo parseInfo = SemanticParseInfo parseInfo = chatManageService.getParseInfo(chatExecuteReq.getQueryId(),
chatManageService.getParseInfo( chatExecuteReq.getParseId());
chatExecuteReq.getQueryId(), chatExecuteReq.getParseId());
Agent agent = agentService.getAgent(chatExecuteReq.getAgentId()); Agent agent = agentService.getAgent(chatExecuteReq.getAgentId());
executeContext.setAgent(agent); executeContext.setAgent(agent);
executeContext.setParseInfo(parseInfo); executeContext.setParseInfo(parseInfo);
@@ -222,12 +222,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return SqlSelectHelper.getAllSelectFields(sqlInfo.getCorrectedS2SQL()); return SqlSelectHelper.getAllSelectFields(sqlInfo.getCorrectedS2SQL());
} }
private void handleLLMQueryMode( private void handleLLMQueryMode(ChatQueryDataReq chatQueryDataReq, SemanticQuery semanticQuery,
ChatQueryDataReq chatQueryDataReq, DataSetSchema dataSetSchema, User user) throws Exception {
SemanticQuery semanticQuery,
DataSetSchema dataSetSchema,
User user)
throws Exception {
SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
List<String> fields = getFieldsFromSql(parseInfo); List<String> fields = getFieldsFromSql(parseInfo);
if (checkMetricReplace(fields, chatQueryDataReq.getMetrics())) { if (checkMetricReplace(fields, chatQueryDataReq.getMetrics())) {
@@ -245,16 +241,16 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
} }
private void handleRuleQueryMode( private void handleRuleQueryMode(SemanticQuery semanticQuery, DataSetSchema dataSetSchema,
SemanticQuery semanticQuery, DataSetSchema dataSetSchema, User user) { User user) {
log.info("rule begin replace metrics and revise filters!"); log.info("rule begin replace metrics and revise filters!");
validFilter(semanticQuery.getParseInfo().getDimensionFilters()); validFilter(semanticQuery.getParseInfo().getDimensionFilters());
validFilter(semanticQuery.getParseInfo().getMetricFilters()); validFilter(semanticQuery.getParseInfo().getMetricFilters());
semanticQuery.initS2Sql(dataSetSchema, user); semanticQuery.initS2Sql(dataSetSchema, user);
} }
private QueryResult executeQuery( private QueryResult executeQuery(SemanticQuery semanticQuery, User user,
SemanticQuery semanticQuery, User user, DataSetSchema dataSetSchema) throws Exception { DataSetSchema dataSetSchema) throws Exception {
SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq(); SemanticQueryReq semanticQueryReq = semanticQuery.buildSemanticQueryReq();
SemanticParseInfo parseInfo = semanticQuery.getParseInfo(); SemanticParseInfo parseInfo = semanticQuery.getParseInfo();
QueryResult queryResult = doExecution(semanticQueryReq, parseInfo.getQueryMode(), user); QueryResult queryResult = doExecution(semanticQueryReq, parseInfo.getQueryMode(), user);
@@ -275,8 +271,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return !oriFields.containsAll(metricNames); return !oriFields.containsAll(metricNames);
} }
private String reviseCorrectS2SQL( private String reviseCorrectS2SQL(ChatQueryDataReq queryData, SemanticParseInfo parseInfo,
ChatQueryDataReq queryData, SemanticParseInfo parseInfo, DataSetSchema dataSetSchema) { DataSetSchema dataSetSchema) {
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL(); String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
log.info("correctorSql before replacing:{}", correctorSql); log.info("correctorSql before replacing:{}", correctorSql);
// get where filter and having filter // get where filter and having filter
@@ -286,21 +282,12 @@ public class ChatQueryServiceImpl implements ChatQueryService {
// replace where filter // replace where filter
List<Expression> addWhereConditions = new ArrayList<>(); List<Expression> addWhereConditions = new ArrayList<>();
Set<String> removeWhereFieldNames = Set<String> removeWhereFieldNames =
updateFilters( updateFilters(whereExpressionList, queryData.getDimensionFilters(),
whereExpressionList, parseInfo.getDimensionFilters(), addWhereConditions);
queryData.getDimensionFilters(),
parseInfo.getDimensionFilters(),
addWhereConditions);
Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>(); Map<String, Map<String, String>> filedNameToValueMap = new HashMap<>();
Set<String> removeDataFieldNames = Set<String> removeDataFieldNames = updateDateInfo(queryData, parseInfo, dataSetSchema,
updateDateInfo( filedNameToValueMap, whereExpressionList, addWhereConditions);
queryData,
parseInfo,
dataSetSchema,
filedNameToValueMap,
whereExpressionList,
addWhereConditions);
removeWhereFieldNames.addAll(removeDataFieldNames); removeWhereFieldNames.addAll(removeDataFieldNames);
correctorSql = SqlReplaceHelper.replaceValue(correctorSql, filedNameToValueMap); correctorSql = SqlReplaceHelper.replaceValue(correctorSql, filedNameToValueMap);
@@ -311,11 +298,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
SqlSelectHelper.getHavingExpressions(correctorSql); SqlSelectHelper.getHavingExpressions(correctorSql);
List<Expression> addHavingConditions = new ArrayList<>(); List<Expression> addHavingConditions = new ArrayList<>();
Set<String> removeHavingFieldNames = Set<String> removeHavingFieldNames =
updateFilters( updateFilters(havingExpressionList, queryData.getDimensionFilters(),
havingExpressionList, parseInfo.getDimensionFilters(), addHavingConditions);
queryData.getDimensionFilters(),
parseInfo.getDimensionFilters(),
addHavingConditions);
correctorSql = SqlReplaceHelper.replaceHavingValue(correctorSql, new HashMap<>()); correctorSql = SqlReplaceHelper.replaceHavingValue(correctorSql, new HashMap<>());
correctorSql = SqlRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames); correctorSql = SqlRemoveHelper.removeHavingCondition(correctorSql, removeHavingFieldNames);
@@ -326,10 +310,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) { private void replaceMetrics(SemanticParseInfo parseInfo, SchemaElement metric) {
List<String> oriMetrics = List<String> oriMetrics = parseInfo.getMetrics().stream().map(SchemaElement::getName)
parseInfo.getMetrics().stream() .collect(Collectors.toList());
.map(SchemaElement::getName)
.collect(Collectors.toList());
String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL(); String correctorSql = parseInfo.getSqlInfo().getCorrectedS2SQL();
log.info("before replaceMetrics:{}", correctorSql); log.info("before replaceMetrics:{}", correctorSql);
log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric); log.info("filteredMetrics:{},metrics:{}", oriMetrics, metric);
@@ -362,20 +344,15 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return queryResult; return queryResult;
} }
private Set<String> updateDateInfo( private Set<String> updateDateInfo(ChatQueryDataReq queryData, SemanticParseInfo parseInfo,
ChatQueryDataReq queryData, DataSetSchema dataSetSchema, Map<String, Map<String, String>> filedNameToValueMap,
SemanticParseInfo parseInfo, List<FieldExpression> fieldExpressionList, List<Expression> addConditions) {
DataSetSchema dataSetSchema,
Map<String, Map<String, String>> filedNameToValueMap,
List<FieldExpression> fieldExpressionList,
List<Expression> addConditions) {
Set<String> removeFieldNames = new HashSet<>(); Set<String> removeFieldNames = new HashSet<>();
if (Objects.isNull(queryData.getDateInfo())) { if (Objects.isNull(queryData.getDateInfo())) {
return removeFieldNames; return removeFieldNames;
} }
if (queryData.getDateInfo().getUnit() > 1) { if (queryData.getDateInfo().getUnit() > 1) {
queryData queryData.getDateInfo()
.getDateInfo()
.setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1)); .setStartDate(DateUtils.getBeforeDate(queryData.getDateInfo().getUnit() + 1));
queryData.getDateInfo().setEndDate(DateUtils.getBeforeDate(0)); queryData.getDateInfo().setEndDate(DateUtils.getBeforeDate(0));
} }
@@ -386,16 +363,10 @@ public class ChatQueryServiceImpl implements ChatQueryService {
// first remove,then add // first remove,then add
removeFieldNames.add(partitionDimension.getName()); removeFieldNames.add(partitionDimension.getName());
GreaterThanEquals greaterThanEquals = new GreaterThanEquals(); GreaterThanEquals greaterThanEquals = new GreaterThanEquals();
addTimeFilters( addTimeFilters(queryData.getDateInfo().getStartDate(), greaterThanEquals,
queryData.getDateInfo().getStartDate(), addConditions, partitionDimension);
greaterThanEquals,
addConditions,
partitionDimension);
MinorThanEquals minorThanEquals = new MinorThanEquals(); MinorThanEquals minorThanEquals = new MinorThanEquals();
addTimeFilters( addTimeFilters(queryData.getDateInfo().getEndDate(), minorThanEquals, addConditions,
queryData.getDateInfo().getEndDate(),
minorThanEquals,
addConditions,
partitionDimension); partitionDimension);
break; break;
} }
@@ -403,8 +374,7 @@ public class ChatQueryServiceImpl implements ChatQueryService {
for (FieldExpression fieldExpression : fieldExpressionList) { for (FieldExpression fieldExpression : fieldExpressionList) {
for (QueryFilter queryFilter : queryData.getDimensionFilters()) { for (QueryFilter queryFilter : queryData.getDimensionFilters()) {
if (queryFilter.getOperator().equals(FilterOperatorEnum.LIKE) if (queryFilter.getOperator().equals(FilterOperatorEnum.LIKE)
&& FilterOperatorEnum.LIKE && FilterOperatorEnum.LIKE.getValue()
.getValue()
.equalsIgnoreCase(fieldExpression.getOperator())) { .equalsIgnoreCase(fieldExpression.getOperator())) {
Map<String, String> replaceMap = new HashMap<>(); Map<String, String> replaceMap = new HashMap<>();
String preValue = fieldExpression.getFieldValue().toString(); String preValue = fieldExpression.getFieldValue().toString();
@@ -425,11 +395,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return removeFieldNames; return removeFieldNames;
} }
private <T extends ComparisonOperator> void addTimeFilters( private <T extends ComparisonOperator> void addTimeFilters(String date, T comparisonExpression,
String date, List<Expression> addConditions, SchemaElement partitionDimension) {
T comparisonExpression,
List<Expression> addConditions,
SchemaElement partitionDimension) {
Column column = new Column(partitionDimension.getName()); Column column = new Column(partitionDimension.getName());
StringValue stringValue = new StringValue(date); StringValue stringValue = new StringValue(date);
comparisonExpression.setLeftExpression(column); comparisonExpression.setLeftExpression(column);
@@ -437,10 +404,8 @@ public class ChatQueryServiceImpl implements ChatQueryService {
addConditions.add(comparisonExpression); addConditions.add(comparisonExpression);
} }
private Set<String> updateFilters( private Set<String> updateFilters(List<FieldExpression> fieldExpressionList,
List<FieldExpression> fieldExpressionList, Set<QueryFilter> metricFilters, Set<QueryFilter> contextMetricFilters,
Set<QueryFilter> metricFilters,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) { List<Expression> addConditions) {
Set<String> removeFieldNames = new HashSet<>(); Set<String> removeFieldNames = new HashSet<>();
if (CollectionUtils.isEmpty(metricFilters)) { if (CollectionUtils.isEmpty(metricFilters)) {
@@ -460,15 +425,13 @@ public class ChatQueryServiceImpl implements ChatQueryService {
return removeFieldNames; return removeFieldNames;
} }
private void handleFilter( private void handleFilter(QueryFilter dslQueryFilter, Set<QueryFilter> contextMetricFilters,
QueryFilter dslQueryFilter,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) { List<Expression> addConditions) {
FilterOperatorEnum operator = dslQueryFilter.getOperator(); FilterOperatorEnum operator = dslQueryFilter.getOperator();
if (operator == FilterOperatorEnum.IN) { if (operator == FilterOperatorEnum.IN) {
addWhereInFilters( addWhereInFilters(dslQueryFilter, new InExpression(), contextMetricFilters,
dslQueryFilter, new InExpression(), contextMetricFilters, addConditions); addConditions);
} else { } else {
ComparisonOperator expression = FilterOperatorEnum.createExpression(operator); ComparisonOperator expression = FilterOperatorEnum.createExpression(operator);
if (Objects.nonNull(expression)) { if (Objects.nonNull(expression)) {
@@ -477,12 +440,9 @@ public class ChatQueryServiceImpl implements ChatQueryService {
} }
} }
// add in condition to sql where condition // add in condition to sql where condition
private void addWhereInFilters( private void addWhereInFilters(QueryFilter dslQueryFilter, InExpression inExpression,
QueryFilter dslQueryFilter, Set<QueryFilter> contextMetricFilters, List<Expression> addConditions) {
InExpression inExpression,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) {
Column column = new Column(dslQueryFilter.getName()); Column column = new Column(dslQueryFilter.getName());
ParenthesedExpressionList parenthesedExpressionList = new ParenthesedExpressionList<>(); ParenthesedExpressionList parenthesedExpressionList = new ParenthesedExpressionList<>();
List<String> valueList = List<String> valueList =
@@ -490,30 +450,24 @@ public class ChatQueryServiceImpl implements ChatQueryService {
if (CollectionUtils.isEmpty(valueList)) { if (CollectionUtils.isEmpty(valueList)) {
return; return;
} }
valueList.stream() valueList.stream().forEach(o -> {
.forEach( StringValue stringValue = new StringValue(o);
o -> { parenthesedExpressionList.add(stringValue);
StringValue stringValue = new StringValue(o); });
parenthesedExpressionList.add(stringValue);
});
inExpression.setLeftExpression(column); inExpression.setLeftExpression(column);
inExpression.setRightExpression(parenthesedExpressionList); inExpression.setRightExpression(parenthesedExpressionList);
addConditions.add(inExpression); addConditions.add(inExpression);
contextMetricFilters.stream() contextMetricFilters.stream().forEach(o -> {
.forEach( if (o.getName().equals(dslQueryFilter.getName())) {
o -> { o.setValue(dslQueryFilter.getValue());
if (o.getName().equals(dslQueryFilter.getName())) { o.setOperator(dslQueryFilter.getOperator());
o.setValue(dslQueryFilter.getValue()); }
o.setOperator(dslQueryFilter.getOperator()); });
}
});
} }
// add where filter // add where filter
private void addWhereFilters( private void addWhereFilters(QueryFilter dslQueryFilter,
QueryFilter dslQueryFilter, ComparisonOperator comparisonExpression, Set<QueryFilter> contextMetricFilters,
ComparisonOperator comparisonExpression,
Set<QueryFilter> contextMetricFilters,
List<Expression> addConditions) { List<Expression> addConditions) {
String columnName = dslQueryFilter.getName(); String columnName = dslQueryFilter.getName();
if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) { if (StringUtils.isNotBlank(dslQueryFilter.getFunction())) {
@@ -533,18 +487,16 @@ public class ChatQueryServiceImpl implements ChatQueryService {
comparisonExpression.setRightExpression(stringValue); comparisonExpression.setRightExpression(stringValue);
} }
addConditions.add(comparisonExpression); addConditions.add(comparisonExpression);
contextMetricFilters.stream() contextMetricFilters.stream().forEach(o -> {
.forEach( if (o.getName().equals(dslQueryFilter.getName())) {
o -> { o.setValue(dslQueryFilter.getValue());
if (o.getName().equals(dslQueryFilter.getName())) { o.setOperator(dslQueryFilter.getOperator());
o.setValue(dslQueryFilter.getValue()); }
o.setOperator(dslQueryFilter.getOperator()); });
}
});
} }
private SemanticParseInfo mergeParseInfo( private SemanticParseInfo mergeParseInfo(SemanticParseInfo parseInfo,
SemanticParseInfo parseInfo, ChatQueryDataReq queryData) { ChatQueryDataReq queryData) {
if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) { if (LLMSqlQuery.QUERY_MODE.equals(parseInfo.getQueryMode())) {
return parseInfo; return parseInfo;
} }

View File

@@ -51,10 +51,8 @@ public class ConfigServiceImpl implements ConfigService {
private final ChatConfigHelper chatConfigHelper; private final ChatConfigHelper chatConfigHelper;
private final SemanticLayerService semanticLayerService; private final SemanticLayerService semanticLayerService;
public ConfigServiceImpl( public ConfigServiceImpl(ChatConfigRepository chatConfigRepository,
ChatConfigRepository chatConfigRepository, ChatConfigHelper chatConfigHelper, SemanticLayerService semanticLayerService) {
ChatConfigHelper chatConfigHelper,
SemanticLayerService semanticLayerService) {
this.chatConfigRepository = chatConfigRepository; this.chatConfigRepository = chatConfigRepository;
this.chatConfigHelper = chatConfigHelper; this.chatConfigHelper = chatConfigHelper;
this.semanticLayerService = semanticLayerService; this.semanticLayerService = semanticLayerService;
@@ -80,9 +78,8 @@ public class ConfigServiceImpl implements ConfigService {
@Override @Override
public Long editConfig(ChatConfigEditReqReq configEditCmd, User user) { public Long editConfig(ChatConfigEditReqReq configEditCmd, User user) {
log.info("[edit model extend] object:{}", JsonUtil.toString(configEditCmd, true)); log.info("[edit model extend] object:{}", JsonUtil.toString(configEditCmd, true));
if (Objects.isNull(configEditCmd) if (Objects.isNull(configEditCmd) || Objects.isNull(configEditCmd.getId())
|| Objects.isNull(configEditCmd.getId()) && Objects.isNull(configEditCmd.getModelId())) {
&& Objects.isNull(configEditCmd.getModelId())) {
throw new RuntimeException( throw new RuntimeException(
"editConfig, id and modelId are not allowed to be empty at the same time"); "editConfig, id and modelId are not allowed to be empty at the same time");
} }
@@ -107,13 +104,13 @@ public class ConfigServiceImpl implements ConfigService {
List<Long> blackDimIdList = new ArrayList<>(); List<Long> blackDimIdList = new ArrayList<>();
if (Objects.nonNull(chatConfig.getChatAggConfig()) if (Objects.nonNull(chatConfig.getChatAggConfig())
&& Objects.nonNull(chatConfig.getChatAggConfig().getVisibility())) { && Objects.nonNull(chatConfig.getChatAggConfig().getVisibility())) {
blackDimIdList.addAll( blackDimIdList
chatConfig.getChatAggConfig().getVisibility().getBlackDimIdList()); .addAll(chatConfig.getChatAggConfig().getVisibility().getBlackDimIdList());
} }
if (Objects.nonNull(chatConfig.getChatDetailConfig()) if (Objects.nonNull(chatConfig.getChatDetailConfig())
&& Objects.nonNull(chatConfig.getChatDetailConfig().getVisibility())) { && Objects.nonNull(chatConfig.getChatDetailConfig().getVisibility())) {
blackDimIdList.addAll( blackDimIdList
chatConfig.getChatDetailConfig().getVisibility().getBlackDimIdList()); .addAll(chatConfig.getChatDetailConfig().getVisibility().getBlackDimIdList());
} }
List<Long> filterDimIdList = List<Long> filterDimIdList =
blackDimIdList.stream().distinct().collect(Collectors.toList()); blackDimIdList.stream().distinct().collect(Collectors.toList());
@@ -121,8 +118,8 @@ public class ConfigServiceImpl implements ConfigService {
List<Long> blackMetricIdList = new ArrayList<>(); List<Long> blackMetricIdList = new ArrayList<>();
if (Objects.nonNull(chatConfig.getChatAggConfig()) if (Objects.nonNull(chatConfig.getChatAggConfig())
&& Objects.nonNull(chatConfig.getChatAggConfig().getVisibility())) { && Objects.nonNull(chatConfig.getChatAggConfig().getVisibility())) {
blackMetricIdList.addAll( blackMetricIdList
chatConfig.getChatAggConfig().getVisibility().getBlackMetricIdList()); .addAll(chatConfig.getChatAggConfig().getVisibility().getBlackMetricIdList());
} }
if (Objects.nonNull(chatConfig.getChatDetailConfig()) if (Objects.nonNull(chatConfig.getChatDetailConfig())
&& Objects.nonNull(chatConfig.getChatDetailConfig().getVisibility())) { && Objects.nonNull(chatConfig.getChatDetailConfig().getVisibility())) {
@@ -138,20 +135,16 @@ public class ConfigServiceImpl implements ConfigService {
if (!CollectionUtils.isEmpty(blackDimIdList)) { if (!CollectionUtils.isEmpty(blackDimIdList)) {
List<DimensionResp> dimensionRespList = semanticLayerService.getDimensions(metaFilter); List<DimensionResp> dimensionRespList = semanticLayerService.getDimensions(metaFilter);
List<String> blackDimNameList = List<String> blackDimNameList =
dimensionRespList.stream() dimensionRespList.stream().filter(o -> filterDimIdList.contains(o.getId()))
.filter(o -> filterDimIdList.contains(o.getId())) .map(SchemaItem::getName).collect(Collectors.toList());
.map(SchemaItem::getName)
.collect(Collectors.toList());
itemNameVisibility.setBlackDimNameList(blackDimNameList); itemNameVisibility.setBlackDimNameList(blackDimNameList);
} }
if (!CollectionUtils.isEmpty(blackMetricIdList)) { if (!CollectionUtils.isEmpty(blackMetricIdList)) {
List<MetricResp> metricRespList = semanticLayerService.getMetrics(metaFilter); List<MetricResp> metricRespList = semanticLayerService.getMetrics(metaFilter);
List<String> blackMetricList = List<String> blackMetricList =
metricRespList.stream() metricRespList.stream().filter(o -> filterMetricIdList.contains(o.getId()))
.filter(o -> filterMetricIdList.contains(o.getId())) .map(SchemaItem::getName).collect(Collectors.toList());
.map(SchemaItem::getName)
.collect(Collectors.toList());
itemNameVisibility.setBlackMetricNameList(blackMetricList); itemNameVisibility.setBlackMetricNameList(blackMetricList);
} }
return itemNameVisibility; return itemNameVisibility;
@@ -169,8 +162,8 @@ public class ConfigServiceImpl implements ConfigService {
return chatConfigRepository.getConfigByModelId(modelId); return chatConfigRepository.getConfigByModelId(modelId);
} }
private ItemVisibilityInfo fetchVisibilityDescByConfig( private ItemVisibilityInfo fetchVisibilityDescByConfig(ItemVisibility visibility,
ItemVisibility visibility, DataSetSchema modelSchema) { DataSetSchema modelSchema) {
ItemVisibilityInfo itemVisibilityDesc = new ItemVisibilityInfo(); ItemVisibilityInfo itemVisibilityDesc = new ItemVisibilityInfo();
List<Long> dimIdAllList = chatConfigHelper.generateAllDimIdList(modelSchema); List<Long> dimIdAllList = chatConfigHelper.generateAllDimIdList(modelSchema);
@@ -186,17 +179,12 @@ public class ConfigServiceImpl implements ConfigService {
blackMetricIdList.addAll(visibility.getBlackMetricIdList()); blackMetricIdList.addAll(visibility.getBlackMetricIdList());
} }
} }
List<Long> whiteMetricIdList = List<Long> whiteMetricIdList = metricIdAllList.stream()
metricIdAllList.stream() .filter(id -> !blackMetricIdList.contains(id) && metricIdAllList.contains(id))
.filter( .collect(Collectors.toList());
id -> List<Long> whiteDimIdList = dimIdAllList.stream()
!blackMetricIdList.contains(id) .filter(id -> !blackDimIdList.contains(id) && dimIdAllList.contains(id))
&& metricIdAllList.contains(id)) .collect(Collectors.toList());
.collect(Collectors.toList());
List<Long> whiteDimIdList =
dimIdAllList.stream()
.filter(id -> !blackDimIdList.contains(id) && dimIdAllList.contains(id))
.collect(Collectors.toList());
itemVisibilityDesc.setBlackDimIdList(blackDimIdList); itemVisibilityDesc.setBlackDimIdList(blackDimIdList);
itemVisibilityDesc.setBlackMetricIdList(blackMetricIdList); itemVisibilityDesc.setBlackMetricIdList(blackMetricIdList);
@@ -232,10 +220,8 @@ public class ConfigServiceImpl implements ConfigService {
return chatConfigRich; return chatConfigRich;
} }
private ChatDetailRichConfigResp fillChatDetailRichConfig( private ChatDetailRichConfigResp fillChatDetailRichConfig(DataSetSchema modelSchema,
DataSetSchema modelSchema, ChatConfigRichResp chatConfigRich, ChatConfigResp chatConfigResp) {
ChatConfigRichResp chatConfigRich,
ChatConfigResp chatConfigResp) {
if (Objects.isNull(chatConfigResp) if (Objects.isNull(chatConfigResp)
|| Objects.isNull(chatConfigResp.getChatDetailConfig())) { || Objects.isNull(chatConfigResp.getChatDetailConfig())) {
return null; return null;
@@ -248,9 +234,8 @@ public class ConfigServiceImpl implements ConfigService {
detailRichConfig.setKnowledgeInfos( detailRichConfig.setKnowledgeInfos(
fillKnowledgeBizName(chatDetailConfig.getKnowledgeInfos(), modelSchema)); fillKnowledgeBizName(chatDetailConfig.getKnowledgeInfos(), modelSchema));
detailRichConfig.setGlobalKnowledgeConfig(chatDetailConfig.getGlobalKnowledgeConfig()); detailRichConfig.setGlobalKnowledgeConfig(chatDetailConfig.getGlobalKnowledgeConfig());
detailRichConfig.setChatDefaultConfig( detailRichConfig.setChatDefaultConfig(fetchDefaultConfig(
fetchDefaultConfig( chatDetailConfig.getChatDefaultConfig(), modelSchema, itemVisibilityInfo));
chatDetailConfig.getChatDefaultConfig(), modelSchema, itemVisibilityInfo));
return detailRichConfig; return detailRichConfig;
} }
@@ -261,18 +246,15 @@ public class ConfigServiceImpl implements ConfigService {
return entityRichInfo; return entityRichInfo;
} }
BeanUtils.copyProperties(entity, entityRichInfo); BeanUtils.copyProperties(entity, entityRichInfo);
Map<Long, SchemaElement> dimIdAndRespPair = Map<Long, SchemaElement> dimIdAndRespPair = modelSchema.getDimensions().stream().collect(
modelSchema.getDimensions().stream() Collectors.toMap(SchemaElement::getId, Function.identity(), (k1, k2) -> k1));
.collect(
Collectors.toMap(
SchemaElement::getId, Function.identity(), (k1, k2) -> k1));
entityRichInfo.setDimItem(dimIdAndRespPair.get(entity.getEntityId())); entityRichInfo.setDimItem(dimIdAndRespPair.get(entity.getEntityId()));
return entityRichInfo; return entityRichInfo;
} }
private ChatAggRichConfigResp fillChatAggRichConfig( private ChatAggRichConfigResp fillChatAggRichConfig(DataSetSchema modelSchema,
DataSetSchema modelSchema, ChatConfigResp chatConfigResp) { ChatConfigResp chatConfigResp) {
if (Objects.isNull(chatConfigResp) || Objects.isNull(chatConfigResp.getChatAggConfig())) { if (Objects.isNull(chatConfigResp) || Objects.isNull(chatConfigResp.getChatAggConfig())) {
return null; return null;
} }
@@ -284,72 +266,53 @@ public class ConfigServiceImpl implements ConfigService {
chatAggRichConfig.setKnowledgeInfos( chatAggRichConfig.setKnowledgeInfos(
fillKnowledgeBizName(chatAggConfig.getKnowledgeInfos(), modelSchema)); fillKnowledgeBizName(chatAggConfig.getKnowledgeInfos(), modelSchema));
chatAggRichConfig.setGlobalKnowledgeConfig(chatAggConfig.getGlobalKnowledgeConfig()); chatAggRichConfig.setGlobalKnowledgeConfig(chatAggConfig.getGlobalKnowledgeConfig());
chatAggRichConfig.setChatDefaultConfig( chatAggRichConfig.setChatDefaultConfig(fetchDefaultConfig(
fetchDefaultConfig( chatAggConfig.getChatDefaultConfig(), modelSchema, itemVisibilityInfo));
chatAggConfig.getChatDefaultConfig(), modelSchema, itemVisibilityInfo));
return chatAggRichConfig; return chatAggRichConfig;
} }
private ChatDefaultRichConfigResp fetchDefaultConfig( private ChatDefaultRichConfigResp fetchDefaultConfig(ChatDefaultConfigReq chatDefaultConfig,
ChatDefaultConfigReq chatDefaultConfig, DataSetSchema modelSchema, ItemVisibilityInfo itemVisibilityInfo) {
DataSetSchema modelSchema,
ItemVisibilityInfo itemVisibilityInfo) {
ChatDefaultRichConfigResp defaultRichConfig = new ChatDefaultRichConfigResp(); ChatDefaultRichConfigResp defaultRichConfig = new ChatDefaultRichConfigResp();
if (Objects.isNull(chatDefaultConfig)) { if (Objects.isNull(chatDefaultConfig)) {
return defaultRichConfig; return defaultRichConfig;
} }
BeanUtils.copyProperties(chatDefaultConfig, defaultRichConfig); BeanUtils.copyProperties(chatDefaultConfig, defaultRichConfig);
Map<Long, SchemaElement> dimIdAndRespPair = Map<Long, SchemaElement> dimIdAndRespPair = modelSchema.getDimensions().stream().collect(
modelSchema.getDimensions().stream() Collectors.toMap(SchemaElement::getId, Function.identity(), (k1, k2) -> k1));
.collect(
Collectors.toMap(
SchemaElement::getId, Function.identity(), (k1, k2) -> k1));
Map<Long, SchemaElement> metricIdAndRespPair = Map<Long, SchemaElement> metricIdAndRespPair = modelSchema.getMetrics().stream().collect(
modelSchema.getMetrics().stream() Collectors.toMap(SchemaElement::getId, Function.identity(), (k1, k2) -> k1));
.collect(
Collectors.toMap(
SchemaElement::getId, Function.identity(), (k1, k2) -> k1));
List<SchemaElement> dimensions = new ArrayList<>(); List<SchemaElement> dimensions = new ArrayList<>();
List<SchemaElement> metrics = new ArrayList<>(); List<SchemaElement> metrics = new ArrayList<>();
if (!CollectionUtils.isEmpty(chatDefaultConfig.getDimensionIds())) { if (!CollectionUtils.isEmpty(chatDefaultConfig.getDimensionIds())) {
chatDefaultConfig.getDimensionIds().stream() chatDefaultConfig.getDimensionIds().stream()
.filter( .filter(dimId -> dimIdAndRespPair.containsKey(dimId)
dimId -> && itemVisibilityInfo.getWhiteDimIdList().contains(dimId))
dimIdAndRespPair.containsKey(dimId) .forEach(dimId -> {
&& itemVisibilityInfo SchemaElement dimSchemaResp = dimIdAndRespPair.get(dimId);
.getWhiteDimIdList() if (Objects.nonNull(dimSchemaResp)) {
.contains(dimId)) SchemaElement dimSchema = new SchemaElement();
.forEach( BeanUtils.copyProperties(dimSchemaResp, dimSchema);
dimId -> { dimensions.add(dimSchema);
SchemaElement dimSchemaResp = dimIdAndRespPair.get(dimId); }
if (Objects.nonNull(dimSchemaResp)) { });
SchemaElement dimSchema = new SchemaElement();
BeanUtils.copyProperties(dimSchemaResp, dimSchema);
dimensions.add(dimSchema);
}
});
} }
if (!CollectionUtils.isEmpty(chatDefaultConfig.getMetricIds())) { if (!CollectionUtils.isEmpty(chatDefaultConfig.getMetricIds())) {
chatDefaultConfig.getMetricIds().stream() chatDefaultConfig.getMetricIds().stream()
.filter( .filter(metricId -> metricIdAndRespPair.containsKey(metricId)
metricId -> && itemVisibilityInfo.getWhiteMetricIdList().contains(metricId))
metricIdAndRespPair.containsKey(metricId) .forEach(metricId -> {
&& itemVisibilityInfo SchemaElement metricSchemaResp = metricIdAndRespPair.get(metricId);
.getWhiteMetricIdList() if (Objects.nonNull(metricSchemaResp)) {
.contains(metricId)) SchemaElement metricSchema = new SchemaElement();
.forEach( BeanUtils.copyProperties(metricSchemaResp, metricSchema);
metricId -> { metrics.add(metricSchema);
SchemaElement metricSchemaResp = metricIdAndRespPair.get(metricId); }
if (Objects.nonNull(metricSchemaResp)) { });
SchemaElement metricSchema = new SchemaElement();
BeanUtils.copyProperties(metricSchemaResp, metricSchema);
metrics.add(metricSchema);
}
});
} }
defaultRichConfig.setDimensions(dimensions); defaultRichConfig.setDimensions(dimensions);
@@ -357,27 +320,21 @@ public class ConfigServiceImpl implements ConfigService {
return defaultRichConfig; return defaultRichConfig;
} }
private List<KnowledgeInfoReq> fillKnowledgeBizName( private List<KnowledgeInfoReq> fillKnowledgeBizName(List<KnowledgeInfoReq> knowledgeInfos,
List<KnowledgeInfoReq> knowledgeInfos, DataSetSchema modelSchema) { DataSetSchema modelSchema) {
if (CollectionUtils.isEmpty(knowledgeInfos)) { if (CollectionUtils.isEmpty(knowledgeInfos)) {
return new ArrayList<>(); return new ArrayList<>();
} }
Map<Long, SchemaElement> dimIdAndRespPair = Map<Long, SchemaElement> dimIdAndRespPair = modelSchema.getDimensions().stream().collect(
modelSchema.getDimensions().stream() Collectors.toMap(SchemaElement::getId, Function.identity(), (k1, k2) -> k1));
.collect( knowledgeInfos.stream().forEach(knowledgeInfo -> {
Collectors.toMap( if (Objects.nonNull(knowledgeInfo)) {
SchemaElement::getId, Function.identity(), (k1, k2) -> k1)); SchemaElement dimSchemaResp = dimIdAndRespPair.get(knowledgeInfo.getItemId());
knowledgeInfos.stream() if (Objects.nonNull(dimSchemaResp)) {
.forEach( knowledgeInfo.setBizName(dimSchemaResp.getBizName());
knowledgeInfo -> { }
if (Objects.nonNull(knowledgeInfo)) { }
SchemaElement dimSchemaResp = });
dimIdAndRespPair.get(knowledgeInfo.getItemId());
if (Objects.nonNull(dimSchemaResp)) {
knowledgeInfo.setBizName(dimSchemaResp.getBizName());
}
}
});
return knowledgeInfos; return knowledgeInfos;
} }

View File

@@ -25,11 +25,14 @@ import java.util.List;
@Service @Service
public class MemoryServiceImpl implements MemoryService { public class MemoryServiceImpl implements MemoryService {
@Autowired private ChatMemoryRepository chatMemoryRepository; @Autowired
private ChatMemoryRepository chatMemoryRepository;
@Autowired private ExemplarService exemplarService; @Autowired
private ExemplarService exemplarService;
@Autowired private EmbeddingConfig embeddingConfig; @Autowired
private EmbeddingConfig embeddingConfig;
@Override @Override
public void createMemory(ChatMemoryDO memory) { public void createMemory(ChatMemoryDO memory) {
@@ -85,20 +88,18 @@ public class MemoryServiceImpl implements MemoryService {
queryWrapper.lambda().eq(ChatMemoryDO::getStatus, chatMemoryFilter.getStatus()); queryWrapper.lambda().eq(ChatMemoryDO::getStatus, chatMemoryFilter.getStatus());
} }
if (chatMemoryFilter.getHumanReviewRet() != null) { if (chatMemoryFilter.getHumanReviewRet() != null) {
queryWrapper queryWrapper.lambda().eq(ChatMemoryDO::getHumanReviewRet,
.lambda() chatMemoryFilter.getHumanReviewRet());
.eq(ChatMemoryDO::getHumanReviewRet, chatMemoryFilter.getHumanReviewRet());
} }
if (chatMemoryFilter.getLlmReviewRet() != null) { if (chatMemoryFilter.getLlmReviewRet() != null) {
queryWrapper queryWrapper.lambda().eq(ChatMemoryDO::getLlmReviewRet,
.lambda() chatMemoryFilter.getLlmReviewRet());
.eq(ChatMemoryDO::getLlmReviewRet, chatMemoryFilter.getLlmReviewRet());
} }
if (StringUtils.isBlank(chatMemoryFilter.getOrderCondition())) { if (StringUtils.isBlank(chatMemoryFilter.getOrderCondition())) {
queryWrapper.orderByDesc("id"); queryWrapper.orderByDesc("id");
} else { } else {
queryWrapper.orderBy( queryWrapper.orderBy(true, chatMemoryFilter.isAsc(),
true, chatMemoryFilter.isAsc(), chatMemoryFilter.getOrderCondition()); chatMemoryFilter.getOrderCondition());
} }
return chatMemoryRepository.getMemories(queryWrapper); return chatMemoryRepository.getMemories(queryWrapper);
} }
@@ -106,9 +107,7 @@ public class MemoryServiceImpl implements MemoryService {
@Override @Override
public List<ChatMemoryDO> getMemoriesForLlmReview() { public List<ChatMemoryDO> getMemoriesForLlmReview() {
QueryWrapper<ChatMemoryDO> queryWrapper = new QueryWrapper<>(); QueryWrapper<ChatMemoryDO> queryWrapper = new QueryWrapper<>();
queryWrapper queryWrapper.lambda().eq(ChatMemoryDO::getStatus, MemoryStatus.PENDING)
.lambda()
.eq(ChatMemoryDO::getStatus, MemoryStatus.PENDING)
.isNull(ChatMemoryDO::getLlmReviewRet); .isNull(ChatMemoryDO::getLlmReviewRet);
return chatMemoryRepository.getMemories(queryWrapper); return chatMemoryRepository.getMemories(queryWrapper);
} }
@@ -116,26 +115,18 @@ public class MemoryServiceImpl implements MemoryService {
@Override @Override
public void enableMemory(ChatMemoryDO memory) { public void enableMemory(ChatMemoryDO memory) {
memory.setStatus(MemoryStatus.ENABLED); memory.setStatus(MemoryStatus.ENABLED);
exemplarService.storeExemplar( exemplarService.storeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
embeddingConfig.getMemoryCollectionName(memory.getAgentId()), Text2SQLExemplar.builder().question(memory.getQuestion())
Text2SQLExemplar.builder() .sideInfo(memory.getSideInfo()).dbSchema(memory.getDbSchema())
.question(memory.getQuestion()) .sql(memory.getS2sql()).build());
.sideInfo(memory.getSideInfo())
.dbSchema(memory.getDbSchema())
.sql(memory.getS2sql())
.build());
} }
@Override @Override
public void disableMemory(ChatMemoryDO memory) { public void disableMemory(ChatMemoryDO memory) {
memory.setStatus(MemoryStatus.DISABLED); memory.setStatus(MemoryStatus.DISABLED);
exemplarService.removeExemplar( exemplarService.removeExemplar(embeddingConfig.getMemoryCollectionName(memory.getAgentId()),
embeddingConfig.getMemoryCollectionName(memory.getAgentId()), Text2SQLExemplar.builder().question(memory.getQuestion())
Text2SQLExemplar.builder() .sideInfo(memory.getSideInfo()).dbSchema(memory.getDbSchema())
.question(memory.getQuestion()) .sql(memory.getS2sql()).build());
.sideInfo(memory.getSideInfo())
.dbSchema(memory.getDbSchema())
.sql(memory.getS2sql())
.build());
} }
} }

View File

@@ -36,8 +36,8 @@ public class PluginServiceImpl implements PluginService {
private ApplicationEventPublisher publisher; private ApplicationEventPublisher publisher;
public PluginServiceImpl( public PluginServiceImpl(PluginRepository pluginRepository,
PluginRepository pluginRepository, ApplicationEventPublisher publisher) { ApplicationEventPublisher publisher) {
this.pluginRepository = pluginRepository; this.pluginRepository = pluginRepository;
this.publisher = publisher; this.publisher = publisher;
} }
@@ -110,18 +110,11 @@ public class PluginServiceImpl implements PluginService {
} }
List<PluginDO> pluginDOS = pluginRepository.query(queryWrapper); List<PluginDO> pluginDOS = pluginRepository.query(queryWrapper);
if (StringUtils.isNotBlank(pluginQueryReq.getPattern())) { if (StringUtils.isNotBlank(pluginQueryReq.getPattern())) {
pluginDOS = pluginDOS = pluginDOS.stream()
pluginDOS.stream() .filter(pluginDO -> pluginDO.getPattern().contains(pluginQueryReq.getPattern())
.filter( || (pluginDO.getName() != null
pluginDO -> && pluginDO.getName().contains(pluginQueryReq.getPattern())))
pluginDO.getPattern() .collect(Collectors.toList());
.contains(pluginQueryReq.getPattern())
|| (pluginDO.getName() != null
&& pluginDO.getName()
.contains(
pluginQueryReq
.getPattern())))
.collect(Collectors.toList());
} }
return convertList(pluginDOS); return convertList(pluginDOS);
} }
@@ -129,16 +122,13 @@ public class PluginServiceImpl implements PluginService {
@Override @Override
public Optional<ChatPlugin> getPluginByName(String name) { public Optional<ChatPlugin> getPluginByName(String name) {
log.info("name:{}", name); log.info("name:{}", name);
return getPluginList().stream() return getPluginList().stream().filter(plugin -> {
.filter( PluginParseConfig functionCallConfig = getPluginParseConfig(plugin);
plugin -> { if (functionCallConfig == null) {
PluginParseConfig functionCallConfig = getPluginParseConfig(plugin); return false;
if (functionCallConfig == null) { }
return false; return functionCallConfig.getName().equalsIgnoreCase(name);
} }).findFirst();
return functionCallConfig.getName().equalsIgnoreCase(name);
})
.findFirst();
} }
private PluginParseConfig getPluginParseConfig(ChatPlugin plugin) { private PluginParseConfig getPluginParseConfig(ChatPlugin plugin) {
@@ -166,26 +156,17 @@ public class PluginServiceImpl implements PluginService {
public Map<String, ChatPlugin> getNameToPlugin() { public Map<String, ChatPlugin> getNameToPlugin() {
List<ChatPlugin> pluginList = getPluginList(); List<ChatPlugin> pluginList = getPluginList();
return pluginList.stream() return pluginList.stream().filter(plugin -> {
.filter( PluginParseConfig functionCallConfig = getPluginParseConfig(plugin);
plugin -> { if (functionCallConfig == null) {
PluginParseConfig functionCallConfig = getPluginParseConfig(plugin); return false;
if (functionCallConfig == null) { }
return false; return true;
} }).collect(Collectors.toMap(a -> {
return true; PluginParseConfig functionCallConfig =
}) JsonUtil.toObject(a.getParseModeConfig(), PluginParseConfig.class);
.collect( return functionCallConfig.getName();
Collectors.toMap( }, a -> a, (k1, k2) -> k1));
a -> {
PluginParseConfig functionCallConfig =
JsonUtil.toObject(
a.getParseModeConfig(),
PluginParseConfig.class);
return functionCallConfig.getName();
},
a -> a,
(k1, k2) -> k1));
} }
// todo // todo
@@ -197,10 +178,8 @@ public class PluginServiceImpl implements PluginService {
ChatPlugin plugin = new ChatPlugin(); ChatPlugin plugin = new ChatPlugin();
BeanUtils.copyProperties(pluginDO, plugin); BeanUtils.copyProperties(pluginDO, plugin);
if (StringUtils.isNotBlank(pluginDO.getDataSet())) { if (StringUtils.isNotBlank(pluginDO.getDataSet())) {
plugin.setDataSetList( plugin.setDataSetList(Arrays.stream(pluginDO.getDataSet().split(","))
Arrays.stream(pluginDO.getDataSet().split(",")) .map(Long::parseLong).collect(Collectors.toList()));
.map(Long::parseLong)
.collect(Collectors.toList()));
} }
return plugin; return plugin;
} }

View File

@@ -14,7 +14,8 @@ import java.util.List;
@Slf4j @Slf4j
public class StatisticsServiceImpl implements StatisticsService { public class StatisticsServiceImpl implements StatisticsService {
@Autowired private StatisticsMapper statisticsMapper; @Autowired
private StatisticsMapper statisticsMapper;
@Async @Async
@Override @Override

View File

@@ -37,10 +37,8 @@ public class ChatConfigHelper {
ChatConfig chatConfig = new ChatConfig(); ChatConfig chatConfig = new ChatConfig();
BeanUtils.copyProperties(extendBaseCmd, chatConfig); BeanUtils.copyProperties(extendBaseCmd, chatConfig);
RecordInfo recordInfo = new RecordInfo(); RecordInfo recordInfo = new RecordInfo();
String creator = String creator = (Objects.isNull(user) || StringUtils.isEmpty(user.getName())) ? ADMIN_LOWER
(Objects.isNull(user) || StringUtils.isEmpty(user.getName())) : user.getName();
? ADMIN_LOWER
: user.getName();
recordInfo.createdBy(creator); recordInfo.createdBy(creator);
chatConfig.setRecordInfo(recordInfo); chatConfig.setRecordInfo(recordInfo);
chatConfig.setStatus(StatusEnum.ONLINE); chatConfig.setStatus(StatusEnum.ONLINE);
@@ -52,10 +50,9 @@ public class ChatConfigHelper {
BeanUtils.copyProperties(extendEditCmd, chatConfig); BeanUtils.copyProperties(extendEditCmd, chatConfig);
RecordInfo recordInfo = new RecordInfo(); RecordInfo recordInfo = new RecordInfo();
String user = String user = (Objects.isNull(facadeUser) || StringUtils.isEmpty(facadeUser.getName()))
(Objects.isNull(facadeUser) || StringUtils.isEmpty(facadeUser.getName())) ? ADMIN_LOWER
? ADMIN_LOWER : facadeUser.getName();
: facadeUser.getName();
recordInfo.updatedBy(user); recordInfo.updatedBy(user);
chatConfig.setRecordInfo(recordInfo); chatConfig.setRecordInfo(recordInfo);
return chatConfig; return chatConfig;
@@ -65,9 +62,8 @@ public class ChatConfigHelper {
if (Objects.isNull(modelSchema) || CollectionUtils.isEmpty(modelSchema.getDimensions())) { if (Objects.isNull(modelSchema) || CollectionUtils.isEmpty(modelSchema.getDimensions())) {
return new ArrayList<>(); return new ArrayList<>();
} }
Map<Long, List<SchemaElement>> dimIdAndDescPair = Map<Long, List<SchemaElement>> dimIdAndDescPair = modelSchema.getDimensions().stream()
modelSchema.getDimensions().stream() .collect(Collectors.groupingBy(SchemaElement::getId));
.collect(Collectors.groupingBy(SchemaElement::getId));
return new ArrayList<>(dimIdAndDescPair.keySet()); return new ArrayList<>(dimIdAndDescPair.keySet());
} }
@@ -75,9 +71,8 @@ public class ChatConfigHelper {
if (Objects.isNull(modelSchema) || CollectionUtils.isEmpty(modelSchema.getMetrics())) { if (Objects.isNull(modelSchema) || CollectionUtils.isEmpty(modelSchema.getMetrics())) {
return new ArrayList<>(); return new ArrayList<>();
} }
Map<Long, List<SchemaElement>> metricIdAndDescPair = Map<Long, List<SchemaElement>> metricIdAndDescPair = modelSchema.getMetrics().stream()
modelSchema.getMetrics().stream() .collect(Collectors.groupingBy(SchemaElement::getId));
.collect(Collectors.groupingBy(SchemaElement::getId));
return new ArrayList<>(metricIdAndDescPair.keySet()); return new ArrayList<>(metricIdAndDescPair.keySet());
} }
@@ -87,8 +82,8 @@ public class ChatConfigHelper {
chatConfigDO.setChatAggConfig(JsonUtil.toString(chatConfig.getChatAggConfig())); chatConfigDO.setChatAggConfig(JsonUtil.toString(chatConfig.getChatAggConfig()));
chatConfigDO.setChatDetailConfig(JsonUtil.toString(chatConfig.getChatDetailConfig())); chatConfigDO.setChatDetailConfig(JsonUtil.toString(chatConfig.getChatDetailConfig()));
chatConfigDO.setRecommendedQuestions( chatConfigDO
JsonUtil.toString(chatConfig.getRecommendedQuestions())); .setRecommendedQuestions(JsonUtil.toString(chatConfig.getRecommendedQuestions()));
if (Objects.isNull(chatConfig.getStatus())) { if (Objects.isNull(chatConfig.getStatus())) {
chatConfigDO.setStatus(null); chatConfigDO.setStatus(null);
@@ -118,9 +113,8 @@ public class ChatConfigHelper {
JsonUtil.toObject(chatConfigDO.getChatDetailConfig(), ChatDetailConfigReq.class)); JsonUtil.toObject(chatConfigDO.getChatDetailConfig(), ChatDetailConfigReq.class));
chatConfigDescriptor.setChatAggConfig( chatConfigDescriptor.setChatAggConfig(
JsonUtil.toObject(chatConfigDO.getChatAggConfig(), ChatAggConfigReq.class)); JsonUtil.toObject(chatConfigDO.getChatAggConfig(), ChatAggConfigReq.class));
chatConfigDescriptor.setRecommendedQuestions( chatConfigDescriptor.setRecommendedQuestions(JsonUtil
JsonUtil.toList( .toList(chatConfigDO.getRecommendedQuestions(), RecommendedQuestionReq.class));
chatConfigDO.getRecommendedQuestions(), RecommendedQuestionReq.class));
chatConfigDescriptor.setStatusEnum(StatusEnum.of(chatConfigDO.getStatus())); chatConfigDescriptor.setStatusEnum(StatusEnum.of(chatConfigDO.getStatus()));
chatConfigDescriptor.setCreatedBy(chatConfigDO.getCreatedBy()); chatConfigDescriptor.setCreatedBy(chatConfigDO.getCreatedBy());

View File

@@ -51,15 +51,13 @@ public class ComponentFactory {
} }
private static <T> List<T> init(Class<T> factoryType, List list) { private static <T> List<T> init(Class<T> factoryType, List list) {
list.addAll( list.addAll(SpringFactoriesLoader.loadFactories(factoryType,
SpringFactoriesLoader.loadFactories( Thread.currentThread().getContextClassLoader()));
factoryType, Thread.currentThread().getContextClassLoader()));
return list; return list;
} }
private static <T> T init(Class<T> factoryType) { private static <T> T init(Class<T> factoryType) {
return SpringFactoriesLoader.loadFactories( return SpringFactoriesLoader
factoryType, Thread.currentThread().getContextClassLoader()) .loadFactories(factoryType, Thread.currentThread().getContextClassLoader()).get(0);
.get(0);
} }
} }

View File

@@ -8,8 +8,8 @@ import java.util.Map;
public class ResultFormatter { public class ResultFormatter {
public static String transform2TextNew( public static String transform2TextNew(List<QueryColumn> queryColumns,
List<QueryColumn> queryColumns, List<Map<String, Object>> queryResults) { List<Map<String, Object>> queryResults) {
if (CollectionUtils.isEmpty(queryColumns)) { if (CollectionUtils.isEmpty(queryColumns)) {
return ""; return "";
} }

View File

@@ -24,13 +24,12 @@ public class LoadRemoveService {
} }
List<String> resultList = new ArrayList<>(value); List<String> resultList = new ArrayList<>(value);
if (StringUtils.isNotBlank(mapperRemoveNaturePrefix)) { if (StringUtils.isNotBlank(mapperRemoveNaturePrefix)) {
resultList.removeIf( resultList.removeIf(nature -> {
nature -> { if (Objects.isNull(nature)) {
if (Objects.isNull(nature)) { return false;
return false; }
} return nature.startsWith(mapperRemoveNaturePrefix);
return nature.startsWith(mapperRemoveNaturePrefix); });
});
} }
return resultList; return resultList;
} }

View File

@@ -253,19 +253,8 @@ public abstract class BaseNode<V> implements Comparable<BaseNode> {
@Override @Override
public String toString() { public String toString() {
return "BaseNode{" return "BaseNode{" + "child=" + Arrays.toString(child) + ", status=" + status + ", c=" + c
+ "child=" + ", value=" + value + ", prefix='" + prefix + '\'' + '}';
+ Arrays.toString(child)
+ ", status="
+ status
+ ", c="
+ c
+ ", value="
+ value
+ ", prefix='"
+ prefix
+ '\''
+ '}';
} }
public void walkNode(Set<Map.Entry<String, V>> entrySet) { public void walkNode(Set<Map.Entry<String, V>> entrySet) {

View File

@@ -34,13 +34,8 @@ public class CoreDictionary {
if (!load(PATH)) { if (!load(PATH)) {
throw new IllegalArgumentException("核心词典" + PATH + "加载失败"); throw new IllegalArgumentException("核心词典" + PATH + "加载失败");
} else { } else {
Predefine.logger.info( Predefine.logger.info(PATH + "加载成功," + trie.size() + "个词条,耗时"
PATH + (System.currentTimeMillis() - start) + "ms");
+ "加载成功,"
+ trie.size()
+ "个词条,耗时"
+ (System.currentTimeMillis() - start)
+ "ms");
} }
} }
@@ -77,22 +72,14 @@ public class CoreDictionary {
map.put(param[0], attribute); map.put(param[0], attribute);
totalFrequency += attribute.totalFrequency; totalFrequency += attribute.totalFrequency;
} }
Predefine.logger.info( Predefine.logger.info("核心词典读入词条" + map.size() + " 全部频次" + totalFrequency + ",耗时"
"核心词典读入词条" + (System.currentTimeMillis() - start) + "ms");
+ map.size()
+ " 全部频次"
+ totalFrequency
+ ",耗时"
+ (System.currentTimeMillis() - start)
+ "ms");
br.close(); br.close();
trie.build(map); trie.build(map);
Predefine.logger.info("核心词典加载成功:" + trie.size() + "个词条,下面将写入缓存……"); Predefine.logger.info("核心词典加载成功:" + trie.size() + "个词条,下面将写入缓存……");
try { try {
DataOutputStream out = DataOutputStream out = new DataOutputStream(
new DataOutputStream( new BufferedOutputStream(IOUtil.newOutputStream(path + Predefine.BIN_EXT)));
new BufferedOutputStream(
IOUtil.newOutputStream(path + Predefine.BIN_EXT)));
Collection<Attribute> attributeList = map.values(); Collection<Attribute> attributeList = map.values();
out.writeInt(attributeList.size()); out.writeInt(attributeList.size());
for (Attribute attribute : attributeList) { for (Attribute attribute : attributeList) {
@@ -278,11 +265,8 @@ public class CoreDictionary {
} }
return attribute; return attribute;
} catch (Exception e) { } catch (Exception e) {
Predefine.logger.warning( Predefine.logger.warning("使用字符串" + natureWithFrequency + "创建词条属性失败!"
"使用字符串" + TextUtility.exceptionToString(e));
+ natureWithFrequency
+ "创建词条属性失败!"
+ TextUtility.exceptionToString(e));
return null; return null;
} }
} }
@@ -409,9 +393,7 @@ public class CoreDictionary {
if (originals == null || originals.length == 0) { if (originals == null || originals.length == 0) {
return null; return null;
} }
return Arrays.stream(originals) return Arrays.stream(originals).filter(o -> o != null).distinct()
.filter(o -> o != null)
.distinct()
.collect(Collectors.toList()); .collect(Collectors.toList());
} }
} }

View File

@@ -47,8 +47,7 @@ public abstract class WordBasedSegment extends Segment {
} }
vertex = (Vertex) var1.next(); vertex = (Vertex) var1.next();
} while (!vertex.realWord.equals("") } while (!vertex.realWord.equals("") && !vertex.realWord.equals("")
&& !vertex.realWord.equals("")
&& !vertex.realWord.equals("-")); && !vertex.realWord.equals("-"));
vertex.confirmNature(Nature.w); vertex.confirmNature(Nature.w);
@@ -66,8 +65,7 @@ public abstract class WordBasedSegment extends Segment {
if (currentNature == Nature.nx if (currentNature == Nature.nx
&& (next.hasNature(Nature.q) || next.hasNature(Nature.n))) { && (next.hasNature(Nature.q) || next.hasNature(Nature.n))) {
String[] param = current.realWord.split("-", 1); String[] param = current.realWord.split("-", 1);
if (param.length == 2 if (param.length == 2 && TextUtility.isAllNum(param[0])
&& TextUtility.isAllNum(param[0])
&& TextUtility.isAllNum(param[1])) { && TextUtility.isAllNum(param[1])) {
current = current.copy(); current = current.copy();
current.realWord = param[0]; current.realWord = param[0];
@@ -112,10 +110,8 @@ public abstract class WordBasedSegment extends Segment {
current.confirmNature(Nature.m, true); current.confirmNature(Nature.m, true);
} else if (current.realWord.length() > 1) { } else if (current.realWord.length() > 1) {
char last = current.realWord.charAt(current.realWord.length() - 1); char last = current.realWord.charAt(current.realWord.length() - 1);
current = current = Vertex.newNumberInstance(
Vertex.newNumberInstance( current.realWord.substring(0, current.realWord.length() - 1));
current.realWord.substring(
0, current.realWord.length() - 1));
listIterator.previous(); listIterator.previous();
listIterator.previous(); listIterator.previous();
listIterator.set(current); listIterator.set(current);
@@ -162,9 +158,7 @@ public abstract class WordBasedSegment extends Segment {
charTypeArray[i] = CharType.get(c); charTypeArray[i] = CharType.get(c);
if (c == '.' && i < charArray.length - 1 && CharType.get(charArray[i + 1]) == 9) { if (c == '.' && i < charArray.length - 1 && CharType.get(charArray[i + 1]) == 9) {
charTypeArray[i] = 9; charTypeArray[i] = 9;
} else if (c == '.' } else if (c == '.' && i < charArray.length - 1 && charArray[i + 1] >= '0'
&& i < charArray.length - 1
&& charArray[i + 1] >= '0'
&& charArray[i + 1] <= '9') { && charArray[i + 1] <= '9') {
charTypeArray[i] = 5; charTypeArray[i] = 5;
} else if (charTypeArray[i] == 8) { } else if (charTypeArray[i] == 8) {
@@ -227,7 +221,7 @@ public abstract class WordBasedSegment extends Segment {
while (listIterator.hasNext()) { while (listIterator.hasNext()) {
next = (Vertex) listIterator.next(); next = (Vertex) listIterator.next();
if (!TextUtility.isAllNum(current.realWord) if (!TextUtility.isAllNum(current.realWord)
&& !TextUtility.isAllChineseNum(current.realWord) && !TextUtility.isAllChineseNum(current.realWord)
|| !TextUtility.isAllNum(next.realWord) || !TextUtility.isAllNum(next.realWord)
&& !TextUtility.isAllChineseNum(next.realWord)) { && !TextUtility.isAllChineseNum(next.realWord)) {
current = next; current = next;
@@ -252,21 +246,16 @@ public abstract class WordBasedSegment extends Segment {
DoubleArrayTrie.Searcher searcher = CoreDictionary.trie.getSearcher(charArray, 0); DoubleArrayTrie.Searcher searcher = CoreDictionary.trie.getSearcher(charArray, 0);
while (searcher.next()) { while (searcher.next()) {
wordNetStorage.add( wordNetStorage.add(searcher.begin + 1,
searcher.begin + 1, new Vertex(new String(charArray, searcher.begin, searcher.length),
new Vertex( (CoreDictionary.Attribute) searcher.value, searcher.index));
new String(charArray, searcher.begin, searcher.length),
(CoreDictionary.Attribute) searcher.value,
searcher.index));
} }
if (this.config.forceCustomDictionary) { if (this.config.forceCustomDictionary) {
this.customDictionary.parseText( this.customDictionary.parseText(charArray,
charArray,
new AhoCorasickDoubleArrayTrie.IHit<CoreDictionary.Attribute>() { new AhoCorasickDoubleArrayTrie.IHit<CoreDictionary.Attribute>() {
public void hit(int begin, int end, CoreDictionary.Attribute value) { public void hit(int begin, int end, CoreDictionary.Attribute value) {
wordNetStorage.add( wordNetStorage.add(begin + 1,
begin + 1,
new Vertex(new String(charArray, begin, end - begin), value)); new Vertex(new String(charArray, begin, end - begin), value));
} }
}); });
@@ -279,11 +268,9 @@ public abstract class WordBasedSegment extends Segment {
while (i < vertexes.length) { while (i < vertexes.length) {
if (vertexes[i].isEmpty()) { if (vertexes[i].isEmpty()) {
int j; int j;
for (j = i + 1; for (j = i + 1; j < vertexes.length - 1 && (vertexes[j].isEmpty()
j < vertexes.length - 1 || CharType.get(charArray[j - 1]) == 11); ++j) {
&& (vertexes[j].isEmpty() }
|| CharType.get(charArray[j - 1]) == 11);
++j) {}
wordNetStorage.add(i, Segment.quickAtomSegment(charArray, i - 1, j - 1)); wordNetStorage.add(i, Segment.quickAtomSegment(charArray, i - 1, j - 1));
i = j; i = j;
@@ -310,10 +297,8 @@ public abstract class WordBasedSegment extends Segment {
addTerms(termList, vertex, line - 1); addTerms(termList, vertex, line - 1);
termMain.offset = line - 1; termMain.offset = line - 1;
if (vertex.realWord.length() > 2) { if (vertex.realWord.length() > 2) {
label43: label43: for (int currentLine = line; currentLine < line
for (int currentLine = line; + vertex.realWord.length(); ++currentLine) {
currentLine < line + vertex.realWord.length();
++currentLine) {
Iterator iterator = wordNetAll.descendingIterator(currentLine); Iterator iterator = wordNetAll.descendingIterator(currentLine);
while (true) { while (true) {
@@ -327,8 +312,8 @@ public abstract class WordBasedSegment extends Segment {
&& smallVertex.realWord.length() < this.config.indexMode); && smallVertex.realWord.length() < this.config.indexMode);
if (smallVertex != vertex if (smallVertex != vertex
&& currentLine + smallVertex.realWord.length() && currentLine + smallVertex.realWord.length() <= line
<= line + vertex.realWord.length()) { + vertex.realWord.length()) {
listIterator.add(smallVertex); listIterator.add(smallVertex);
// Term termSub = convert(smallVertex); // Term termSub = convert(smallVertex);
// termSub.offset = currentLine - 1; // termSub.offset = currentLine - 1;
@@ -346,8 +331,8 @@ public abstract class WordBasedSegment extends Segment {
} }
protected static void speechTagging(List<Vertex> vertexList) { protected static void speechTagging(List<Vertex> vertexList) {
Viterbi.compute( Viterbi.compute(vertexList,
vertexList, CoreDictionaryTransformMatrixDictionary.transformMatrixDictionary); CoreDictionaryTransformMatrixDictionary.transformMatrixDictionary);
} }
protected void addTerms(List<Term> terms, Vertex vertex, int offset) { protected void addTerms(List<Term> terms, Vertex vertex, int offset) {

View File

@@ -42,19 +42,13 @@ public class Term {
} }
// todo opt // todo opt
/* /*
String wordOri = word.toLowerCase(); * String wordOri = word.toLowerCase(); CoreDictionary.Attribute attribute =
CoreDictionary.Attribute attribute = getDynamicCustomDictionary().get(wordOri); * getDynamicCustomDictionary().get(wordOri); if (attribute == null) { attribute =
if (attribute == null) { * CoreDictionary.get(wordOri); if (attribute == null) { attribute =
attribute = CoreDictionary.get(wordOri); * CustomDictionary.get(wordOri); } } if (attribute != null && nature != null &&
if (attribute == null) { * attribute.hasNature(nature)) { return attribute.getNatureFrequency(nature); } return
attribute = CustomDictionary.get(wordOri); * attribute == null ? 0 : attribute.totalFrequency;
} */
}
if (attribute != null && nature != null && attribute.hasNature(nature)) {
return attribute.getNatureFrequency(nature);
}
return attribute == null ? 0 : attribute.totalFrequency;
*/
return 0; return 0;
} }

View File

@@ -51,19 +51,18 @@ public class Configuration {
public static SqlValidator.Config getValidatorConfig(EngineType engineType) { public static SqlValidator.Config getValidatorConfig(EngineType engineType) {
SemanticSqlDialect sqlDialect = SqlDialectFactory.getSqlDialect(engineType); SemanticSqlDialect sqlDialect = SqlDialectFactory.getSqlDialect(engineType);
return SqlValidator.Config.DEFAULT return SqlValidator.Config.DEFAULT.withConformance(sqlDialect.getConformance())
.withConformance(sqlDialect.getConformance())
.withDefaultNullCollation(config.defaultNullCollation()) .withDefaultNullCollation(config.defaultNullCollation())
.withLenientOperatorLookup(true); .withLenientOperatorLookup(true);
} }
static { static {
configProperties.put( configProperties.put(CalciteConnectionProperty.CASE_SENSITIVE.camelName(),
CalciteConnectionProperty.CASE_SENSITIVE.camelName(), Boolean.TRUE.toString()); Boolean.TRUE.toString());
configProperties.put( configProperties.put(CalciteConnectionProperty.UNQUOTED_CASING.camelName(),
CalciteConnectionProperty.UNQUOTED_CASING.camelName(), Casing.UNCHANGED.toString()); Casing.UNCHANGED.toString());
configProperties.put( configProperties.put(CalciteConnectionProperty.QUOTED_CASING.camelName(),
CalciteConnectionProperty.QUOTED_CASING.camelName(), Casing.TO_LOWER.toString()); Casing.TO_LOWER.toString());
} }
public static SqlParser.Config getParserConfig(EngineType engineType) { public static SqlParser.Config getParserConfig(EngineType engineType) {
@@ -76,15 +75,10 @@ public class Configuration {
parserConfig.setQuotedCasing(config.quotedCasing()); parserConfig.setQuotedCasing(config.quotedCasing());
parserConfig.setConformance(config.conformance()); parserConfig.setConformance(config.conformance());
parserConfig.setLex(Lex.BIG_QUERY); parserConfig.setLex(Lex.BIG_QUERY);
parserConfig parserConfig.setParserFactory(SqlParserImpl.FACTORY).setCaseSensitive(false)
.setParserFactory(SqlParserImpl.FACTORY) .setIdentifierMaxLength(Integer.MAX_VALUE).setQuoting(Quoting.BACK_TICK)
.setCaseSensitive(false) .setQuoting(Quoting.SINGLE_QUOTE).setQuotedCasing(Casing.TO_UPPER)
.setIdentifierMaxLength(Integer.MAX_VALUE) .setUnquotedCasing(Casing.TO_UPPER).setConformance(sqlDialect.getConformance())
.setQuoting(Quoting.BACK_TICK)
.setQuoting(Quoting.SINGLE_QUOTE)
.setQuotedCasing(Casing.TO_UPPER)
.setUnquotedCasing(Casing.TO_UPPER)
.setConformance(sqlDialect.getConformance())
.setLex(Lex.BIG_QUERY); .setLex(Lex.BIG_QUERY);
parserConfig = parserConfig.setQuotedCasing(Casing.UNCHANGED); parserConfig = parserConfig.setQuotedCasing(Casing.UNCHANGED);
parserConfig = parserConfig.setUnquotedCasing(Casing.UNCHANGED); parserConfig = parserConfig.setUnquotedCasing(Casing.UNCHANGED);
@@ -96,61 +90,39 @@ public class Configuration {
tables.add(SqlStdOperatorTable.instance()); tables.add(SqlStdOperatorTable.instance());
SqlOperatorTable operatorTable = new ChainedSqlOperatorTable(tables); SqlOperatorTable operatorTable = new ChainedSqlOperatorTable(tables);
// operatorTable. // operatorTable.
Prepare.CatalogReader catalogReader = Prepare.CatalogReader catalogReader = new CalciteCatalogReader(rootSchema,
new CalciteCatalogReader( Collections.singletonList(rootSchema.getName()), typeFactory, config);
rootSchema, return SqlValidatorUtil.newValidator(operatorTable, catalogReader, typeFactory,
Collections.singletonList(rootSchema.getName()),
typeFactory,
config);
return SqlValidatorUtil.newValidator(
operatorTable,
catalogReader,
typeFactory,
Configuration.getValidatorConfig(engineType)); Configuration.getValidatorConfig(engineType));
} }
public static SqlValidatorWithHints getSqlValidatorWithHints( public static SqlValidatorWithHints getSqlValidatorWithHints(CalciteSchema rootSchema,
CalciteSchema rootSchema, EngineType engineTyp) { EngineType engineTyp) {
return new SqlAdvisorValidator( return new SqlAdvisorValidator(SqlStdOperatorTable.instance(),
SqlStdOperatorTable.instance(), new CalciteCatalogReader(rootSchema,
new CalciteCatalogReader( Collections.singletonList(rootSchema.getName()), typeFactory, config),
rootSchema, typeFactory, SqlValidator.Config.DEFAULT);
Collections.singletonList(rootSchema.getName()),
typeFactory,
config),
typeFactory,
SqlValidator.Config.DEFAULT);
} }
public static SqlToRelConverter.Config getConverterConfig() { public static SqlToRelConverter.Config getConverterConfig() {
HintStrategyTable strategies = HintStrategyTable.builder().build(); HintStrategyTable strategies = HintStrategyTable.builder().build();
return SqlToRelConverter.config() return SqlToRelConverter.config().withHintStrategyTable(strategies)
.withHintStrategyTable(strategies) .withTrimUnusedFields(true).withExpand(true)
.withTrimUnusedFields(true)
.withExpand(true)
.addRelBuilderConfigTransform(c -> c.withSimplify(false)); .addRelBuilderConfigTransform(c -> c.withSimplify(false));
} }
public static SqlToRelConverter getSqlToRelConverter( public static SqlToRelConverter getSqlToRelConverter(SqlValidatorScope scope,
SqlValidatorScope scope, SqlValidator sqlValidator, RelOptPlanner relOptPlanner, EngineType engineType) {
SqlValidator sqlValidator,
RelOptPlanner relOptPlanner,
EngineType engineType) {
RexBuilder rexBuilder = new RexBuilder(typeFactory); RexBuilder rexBuilder = new RexBuilder(typeFactory);
RelOptCluster cluster = RelOptCluster.create(relOptPlanner, rexBuilder); RelOptCluster cluster = RelOptCluster.create(relOptPlanner, rexBuilder);
FrameworkConfig fromworkConfig = FrameworkConfig fromworkConfig =
Frameworks.newConfigBuilder() Frameworks.newConfigBuilder().parserConfig(getParserConfig(engineType))
.parserConfig(getParserConfig(engineType))
.defaultSchema( .defaultSchema(
scope.getValidator().getCatalogReader().getRootSchema().plus()) scope.getValidator().getCatalogReader().getRootSchema().plus())
.build(); .build();
return new SqlToRelConverter( return new SqlToRelConverter(new ViewExpanderImpl(), sqlValidator,
new ViewExpanderImpl(), (CatalogReader) scope.getValidator().getCatalogReader(), cluster,
sqlValidator, fromworkConfig.getConvertletTable(), getConverterConfig());
(CatalogReader) scope.getValidator().getCatalogReader(),
cluster,
fromworkConfig.getConvertletTable(),
getConverterConfig());
} }
public static SqlAdvisor getSqlAdvisor(SqlValidatorWithHints validator, EngineType engineType) { public static SqlAdvisor getSqlAdvisor(SqlValidatorWithHints validator, EngineType engineType) {
@@ -159,15 +131,10 @@ public class Configuration {
public static SqlWriterConfig getSqlWriterConfig(EngineType engineType) { public static SqlWriterConfig getSqlWriterConfig(EngineType engineType) {
SemanticSqlDialect sqlDialect = SqlDialectFactory.getSqlDialect(engineType); SemanticSqlDialect sqlDialect = SqlDialectFactory.getSqlDialect(engineType);
SqlWriterConfig config = SqlWriterConfig config = SqlPrettyWriter.config().withDialect(sqlDialect)
SqlPrettyWriter.config() .withKeywordsLowerCase(false).withClauseEndsLine(true)
.withDialect(sqlDialect) .withAlwaysUseParentheses(false).withSelectListItemsOnSeparateLines(false)
.withKeywordsLowerCase(false) .withUpdateSetListNewline(false).withIndentation(0);
.withClauseEndsLine(true)
.withAlwaysUseParentheses(false)
.withSelectListItemsOnSeparateLines(false)
.withUpdateSetListNewline(false)
.withIndentation(0);
if (EngineType.MYSQL.equals(engineType)) { if (EngineType.MYSQL.equals(engineType)) {
// no backticks around function name // no backticks around function name
config = config.withQuoteAllIdentifiers(false); config = config.withQuoteAllIdentifiers(false);

View File

@@ -17,8 +17,8 @@ public class SemanticSqlDialect extends SqlDialect {
super(context); super(context);
} }
public static void unparseFetchUsingAnsi( public static void unparseFetchUsingAnsi(SqlWriter writer, @Nullable SqlNode offset,
SqlWriter writer, @Nullable SqlNode offset, @Nullable SqlNode fetch) { @Nullable SqlNode fetch) {
Preconditions.checkArgument(fetch != null || offset != null); Preconditions.checkArgument(fetch != null || offset != null);
SqlWriter.Frame fetchFrame; SqlWriter.Frame fetchFrame;
writer.newlineAndIndent(); writer.newlineAndIndent();
@@ -74,11 +74,11 @@ public class SemanticSqlDialect extends SqlDialect {
return true; return true;
} }
public void unparseSqlIntervalLiteral( public void unparseSqlIntervalLiteral(SqlWriter writer, SqlIntervalLiteral literal,
SqlWriter writer, SqlIntervalLiteral literal, int leftPrec, int rightPrec) {} int leftPrec, int rightPrec) {}
public void unparseOffsetFetch( public void unparseOffsetFetch(SqlWriter writer, @Nullable SqlNode offset,
SqlWriter writer, @Nullable SqlNode offset, @Nullable SqlNode fetch) { @Nullable SqlNode fetch) {
unparseFetchUsingAnsi(writer, offset, fetch); unparseFetchUsingAnsi(writer, offset, fetch);
} }
} }

View File

@@ -13,22 +13,14 @@ import java.util.Objects;
public class SqlDialectFactory { public class SqlDialectFactory {
public static final Context DEFAULT_CONTEXT = public static final Context DEFAULT_CONTEXT =
SqlDialect.EMPTY_CONTEXT SqlDialect.EMPTY_CONTEXT.withDatabaseProduct(DatabaseProduct.BIG_QUERY)
.withDatabaseProduct(DatabaseProduct.BIG_QUERY) .withLiteralQuoteString("'").withLiteralEscapedQuoteString("''")
.withLiteralQuoteString("'") .withIdentifierQuoteString("`").withUnquotedCasing(Casing.UNCHANGED)
.withLiteralEscapedQuoteString("''") .withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(false);
.withIdentifierQuoteString("`") public static final Context POSTGRESQL_CONTEXT = SqlDialect.EMPTY_CONTEXT
.withUnquotedCasing(Casing.UNCHANGED) .withDatabaseProduct(DatabaseProduct.BIG_QUERY).withLiteralQuoteString("'")
.withQuotedCasing(Casing.UNCHANGED) .withLiteralEscapedQuoteString("''").withUnquotedCasing(Casing.UNCHANGED)
.withCaseSensitive(false); .withQuotedCasing(Casing.UNCHANGED).withCaseSensitive(false);
public static final Context POSTGRESQL_CONTEXT =
SqlDialect.EMPTY_CONTEXT
.withDatabaseProduct(DatabaseProduct.BIG_QUERY)
.withLiteralQuoteString("'")
.withLiteralEscapedQuoteString("''")
.withUnquotedCasing(Casing.UNCHANGED)
.withQuotedCasing(Casing.UNCHANGED)
.withCaseSensitive(false);
private static Map<EngineType, SemanticSqlDialect> sqlDialectMap; private static Map<EngineType, SemanticSqlDialect> sqlDialectMap;
static { static {

View File

@@ -20,12 +20,8 @@ import java.util.List;
@Slf4j @Slf4j
public class SqlMergeWithUtils { public class SqlMergeWithUtils {
public static String mergeWith( public static String mergeWith(EngineType engineType, String sql, List<String> parentSqlList,
EngineType engineType, List<String> parentWithNameList) throws SqlParseException {
String sql,
List<String> parentSqlList,
List<String> parentWithNameList)
throws SqlParseException {
SqlParser.Config parserConfig = Configuration.getParserConfig(engineType); SqlParser.Config parserConfig = Configuration.getParserConfig(engineType);
// Parse the main SQL statement // Parse the main SQL statement
@@ -45,14 +41,12 @@ public class SqlMergeWithUtils {
SqlNode sqlNode2 = parser.parseQuery(); SqlNode sqlNode2 = parser.parseQuery();
// Create a new WITH item for parentWithName without quotes // Create a new WITH item for parentWithName without quotes
SqlWithItem withItem = SqlWithItem withItem = new SqlWithItem(SqlParserPos.ZERO,
new SqlWithItem( new SqlIdentifier(parentWithName, SqlParserPos.ZERO), // false
SqlParserPos.ZERO, // to
new SqlIdentifier( // avoid
parentWithName, SqlParserPos.ZERO), // false to avoid quotes // quotes
null, null, sqlNode2, SqlLiteral.createBoolean(false, SqlParserPos.ZERO));
sqlNode2,
SqlLiteral.createBoolean(false, SqlParserPos.ZERO));
// Add the new WITH item to the list // Add the new WITH item to the list
withItemList.add(withItem); withItemList.add(withItem);
@@ -66,11 +60,8 @@ public class SqlMergeWithUtils {
} }
// Create a new SqlWith node // Create a new SqlWith node
SqlWith finalSqlNode = SqlWith finalSqlNode = new SqlWith(SqlParserPos.ZERO,
new SqlWith( new SqlNodeList(withItemList, SqlParserPos.ZERO), sqlNode1);
SqlParserPos.ZERO,
new SqlNodeList(withItemList, SqlParserPos.ZERO),
sqlNode1);
// Custom SqlPrettyWriter configuration to avoid quoting identifiers // Custom SqlPrettyWriter configuration to avoid quoting identifiers
SqlWriterConfig config = Configuration.getSqlWriterConfig(engineType); SqlWriterConfig config = Configuration.getSqlWriterConfig(engineType);
// Pretty print the final SQL // Pretty print the final SQL

View File

@@ -45,10 +45,8 @@ public class SqlParseUtils {
sqlParserInfo.setAllFields( sqlParserInfo.setAllFields(
sqlParserInfo.getAllFields().stream().distinct().collect(Collectors.toList())); sqlParserInfo.getAllFields().stream().distinct().collect(Collectors.toList()));
sqlParserInfo.setSelectFields( sqlParserInfo.setSelectFields(sqlParserInfo.getSelectFields().stream().distinct()
sqlParserInfo.getSelectFields().stream() .collect(Collectors.toList()));
.distinct()
.collect(Collectors.toList()));
return sqlParserInfo; return sqlParserInfo;
} catch (SqlParseException e) { } catch (SqlParseException e) {
@@ -108,13 +106,10 @@ public class SqlParseUtils {
SqlSelect sqlSelect = (SqlSelect) select; SqlSelect sqlSelect = (SqlSelect) select;
SqlNodeList selectList = sqlSelect.getSelectList(); SqlNodeList selectList = sqlSelect.getSelectList();
selectList selectList.getList().forEach(list -> {
.getList() Set<String> selectFields = handlerField(list);
.forEach( sqlParserInfo.getSelectFields().addAll(selectFields);
list -> { });
Set<String> selectFields = handlerField(list);
sqlParserInfo.getSelectFields().addAll(selectFields);
});
String tableName = handlerFrom(sqlSelect.getFrom()); String tableName = handlerFrom(sqlSelect.getFrom());
sqlParserInfo.setTableName(tableName); sqlParserInfo.setTableName(tableName);
@@ -129,14 +124,10 @@ public class SqlParseUtils {
results.addAll(formFields); results.addAll(formFields);
} }
sqlSelect sqlSelect.getSelectList().getList().forEach(list -> {
.getSelectList() Set<String> selectFields = handlerField(list);
.getList() results.addAll(selectFields);
.forEach( });
list -> {
Set<String> selectFields = handlerField(list);
results.addAll(selectFields);
});
if (sqlSelect.hasWhere()) { if (sqlSelect.hasWhere()) {
Set<String> whereFields = handlerField(sqlSelect.getWhere()); Set<String> whereFields = handlerField(sqlSelect.getWhere());
@@ -148,11 +139,10 @@ public class SqlParseUtils {
} }
SqlNodeList group = sqlSelect.getGroup(); SqlNodeList group = sqlSelect.getGroup();
if (group != null) { if (group != null) {
group.forEach( group.forEach(groupField -> {
groupField -> { Set<String> groupByFields = handlerField(groupField);
Set<String> groupByFields = handlerField(groupField); results.addAll(groupByFields);
results.addAll(groupByFields); });
});
} }
return results; return results;
} }
@@ -213,12 +203,9 @@ public class SqlParseUtils {
} }
} }
if (field instanceof SqlNodeList) { if (field instanceof SqlNodeList) {
((SqlNodeList) field) ((SqlNodeList) field).getList().forEach(node -> {
.getList() fields.addAll(handlerField(node));
.forEach( });
node -> {
fields.addAll(handlerField(node));
});
} }
break; break;
} }
@@ -243,12 +230,9 @@ public class SqlParseUtils {
SqlIdentifier sqlIdentifier = (SqlIdentifier) operandList.get(0); SqlIdentifier sqlIdentifier = (SqlIdentifier) operandList.get(0);
String simple = sqlIdentifier.getSimple(); String simple = sqlIdentifier.getSimple();
SqlBasicCall aliasedNode = SqlBasicCall aliasedNode =
new SqlBasicCall( new SqlBasicCall(SqlStdOperatorTable.AS,
SqlStdOperatorTable.AS, new SqlNode[] {sqlBasicCall, new SqlIdentifier(
new SqlNode[] { simple.toLowerCase(), SqlParserPos.ZERO)},
sqlBasicCall,
new SqlIdentifier(simple.toLowerCase(), SqlParserPos.ZERO)
},
SqlParserPos.ZERO); SqlParserPos.ZERO);
selectList.set(selectList.indexOf(node), aliasedNode); selectList.set(selectList.indexOf(node), aliasedNode);
} }

View File

@@ -11,10 +11,7 @@ public class ViewExpanderImpl implements RelOptTable.ViewExpander {
public ViewExpanderImpl() {} public ViewExpanderImpl() {}
@Override @Override
public RelRoot expandView( public RelRoot expandView(RelDataType rowType, String queryString, List<String> schemaPath,
RelDataType rowType,
String queryString,
List<String> schemaPath,
List<String> dataSetPath) { List<String> dataSetPath) {
return null; return null;
} }

View File

@@ -20,98 +20,37 @@ import java.util.List;
@Slf4j @Slf4j
public class ChatModelParameterConfig extends ParameterConfig { public class ChatModelParameterConfig extends ParameterConfig {
public static final Parameter CHAT_MODEL_PROVIDER = public static final Parameter CHAT_MODEL_PROVIDER = new Parameter("s2.chat.model.provider",
new Parameter( OpenAiModelFactory.PROVIDER, "接口协议", "", "list", "对话模型配置", getCandidateValues());
"s2.chat.model.provider",
OpenAiModelFactory.PROVIDER,
"接口协议",
"",
"list",
"对话模型配置",
getCandidateValues());
public static final Parameter CHAT_MODEL_BASE_URL = public static final Parameter CHAT_MODEL_BASE_URL =
new Parameter( new Parameter("s2.chat.model.base.url", OpenAiModelFactory.DEFAULT_BASE_URL, "BaseUrl",
"s2.chat.model.base.url", "", "string", "对话模型配置", null, getBaseUrlDependency());
OpenAiModelFactory.DEFAULT_BASE_URL, public static final Parameter CHAT_MODEL_ENDPOINT = new Parameter("s2.chat.model.endpoint",
"BaseUrl", "llama_2_70b", "Endpoint", "", "string", "对话模型配置", null, getEndpointDependency());
"", public static final Parameter CHAT_MODEL_API_KEY = new Parameter("s2.chat.model.api.key", DEMO,
"string", "ApiKey", "", "password", "对话模型配置", null, getApiKeyDependency());
"对话模型配置", public static final Parameter CHAT_MODEL_SECRET_KEY = new Parameter("s2.chat.model.secretKey",
null, "demo", "SecretKey", "", "password", "对话模型配置", null, getSecretKeyDependency());
getBaseUrlDependency());
public static final Parameter CHAT_MODEL_ENDPOINT =
new Parameter(
"s2.chat.model.endpoint",
"llama_2_70b",
"Endpoint",
"",
"string",
"对话模型配置",
null,
getEndpointDependency());
public static final Parameter CHAT_MODEL_API_KEY =
new Parameter(
"s2.chat.model.api.key",
DEMO,
"ApiKey",
"",
"password",
"对话模型配置",
null,
getApiKeyDependency());
public static final Parameter CHAT_MODEL_SECRET_KEY =
new Parameter(
"s2.chat.model.secretKey",
"demo",
"SecretKey",
"",
"password",
"对话模型配置",
null,
getSecretKeyDependency());
public static final Parameter CHAT_MODEL_NAME = public static final Parameter CHAT_MODEL_NAME = new Parameter("s2.chat.model.name",
new Parameter( "gpt-4o-mini", "ModelName", "", "string", "对话模型配置", null, getModelNameDependency());
"s2.chat.model.name",
"gpt-4o-mini",
"ModelName",
"",
"string",
"对话模型配置",
null,
getModelNameDependency());
public static final Parameter CHAT_MODEL_ENABLE_SEARCH = public static final Parameter CHAT_MODEL_ENABLE_SEARCH =
new Parameter( new Parameter("s2.chat.model.enableSearch", "false", "是否启用搜索增强功能设为false表示不启用", "",
"s2.chat.model.enableSearch", "bool", "对话模型配置", null, getEnableSearchDependency());
"false",
"是否启用搜索增强功能设为false表示不启用",
"",
"bool",
"对话模型配置",
null,
getEnableSearchDependency());
public static final Parameter CHAT_MODEL_TEMPERATURE = public static final Parameter CHAT_MODEL_TEMPERATURE = new Parameter(
new Parameter( "s2.chat.model.temperature", "0.0", "Temperature", "", "slider", "对话模型配置");
"s2.chat.model.temperature", "0.0", "Temperature", "", "slider", "对话模型配置");
public static final Parameter CHAT_MODEL_TIMEOUT = public static final Parameter CHAT_MODEL_TIMEOUT =
new Parameter("s2.chat.model.timeout", "60", "超时时间(秒)", "", "number", "对话模型配置"); new Parameter("s2.chat.model.timeout", "60", "超时时间(秒)", "", "number", "对话模型配置");
@Override @Override
public List<Parameter> getSysParameters() { public List<Parameter> getSysParameters() {
return Lists.newArrayList( return Lists.newArrayList(CHAT_MODEL_PROVIDER, CHAT_MODEL_BASE_URL, CHAT_MODEL_ENDPOINT,
CHAT_MODEL_PROVIDER, CHAT_MODEL_API_KEY, CHAT_MODEL_SECRET_KEY, CHAT_MODEL_NAME,
CHAT_MODEL_BASE_URL, CHAT_MODEL_ENABLE_SEARCH, CHAT_MODEL_TEMPERATURE, CHAT_MODEL_TIMEOUT);
CHAT_MODEL_ENDPOINT,
CHAT_MODEL_API_KEY,
CHAT_MODEL_SECRET_KEY,
CHAT_MODEL_NAME,
CHAT_MODEL_ENABLE_SEARCH,
CHAT_MODEL_TEMPERATURE,
CHAT_MODEL_TIMEOUT);
} }
public ChatModelConfig convert() { public ChatModelConfig convert() {
@@ -125,36 +64,24 @@ public class ChatModelParameterConfig extends ParameterConfig {
String secretKey = getParameterValue(CHAT_MODEL_SECRET_KEY); String secretKey = getParameterValue(CHAT_MODEL_SECRET_KEY);
String enableSearch = getParameterValue(CHAT_MODEL_ENABLE_SEARCH); String enableSearch = getParameterValue(CHAT_MODEL_ENABLE_SEARCH);
return ChatModelConfig.builder() return ChatModelConfig.builder().provider(chatModelProvider).baseUrl(chatModelBaseUrl)
.provider(chatModelProvider) .apiKey(chatModelApiKey).modelName(chatModelName)
.baseUrl(chatModelBaseUrl)
.apiKey(chatModelApiKey)
.modelName(chatModelName)
.enableSearch(Boolean.valueOf(enableSearch)) .enableSearch(Boolean.valueOf(enableSearch))
.temperature(Double.valueOf(chatModelTemperature)) .temperature(Double.valueOf(chatModelTemperature))
.timeOut(Long.valueOf(chatModelTimeout)) .timeOut(Long.valueOf(chatModelTimeout)).endpoint(endpoint).secretKey(secretKey)
.endpoint(endpoint)
.secretKey(secretKey)
.build(); .build();
} }
private static List<String> getCandidateValues() { private static List<String> getCandidateValues() {
return Lists.newArrayList( return Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER,
OpenAiModelFactory.PROVIDER, QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER, LocalAiModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER,
LocalAiModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
AzureModelFactory.PROVIDER); AzureModelFactory.PROVIDER);
} }
private static List<Parameter.Dependency> getBaseUrlDependency() { private static List<Parameter.Dependency> getBaseUrlDependency() {
return getDependency( return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(),
CHAT_MODEL_PROVIDER.getName(), ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
getCandidateValues(),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL, AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL, OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL, QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_BASE_URL,
@@ -164,30 +91,18 @@ public class ChatModelParameterConfig extends ParameterConfig {
} }
private static List<Parameter.Dependency> getApiKeyDependency() { private static List<Parameter.Dependency> getApiKeyDependency() {
return getDependency( return getDependency(CHAT_MODEL_PROVIDER.getName(),
CHAT_MODEL_PROVIDER.getName(), Lists.newArrayList(OpenAiModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
Lists.newArrayList( ZhipuModelFactory.PROVIDER, LocalAiModelFactory.PROVIDER,
OpenAiModelFactory.PROVIDER, AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER),
QianfanModelFactory.PROVIDER, ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO, QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER, DEMO, ZhipuModelFactory.PROVIDER, DEMO, LocalAiModelFactory.PROVIDER, DEMO,
LocalAiModelFactory.PROVIDER, AzureModelFactory.PROVIDER, DEMO, DashscopeModelFactory.PROVIDER, DEMO));
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, DEMO,
QianfanModelFactory.PROVIDER, DEMO,
ZhipuModelFactory.PROVIDER, DEMO,
LocalAiModelFactory.PROVIDER, DEMO,
AzureModelFactory.PROVIDER, DEMO,
DashscopeModelFactory.PROVIDER, DEMO));
} }
private static List<Parameter.Dependency> getModelNameDependency() { private static List<Parameter.Dependency> getModelNameDependency() {
return getDependency( return getDependency(CHAT_MODEL_PROVIDER.getName(), getCandidateValues(),
CHAT_MODEL_PROVIDER.getName(), ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME,
getCandidateValues(),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_MODEL_NAME,
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_MODEL_NAME, OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_MODEL_NAME,
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_MODEL_NAME, QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_MODEL_NAME,
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_MODEL_NAME, ZhipuModelFactory.PROVIDER, ZhipuModelFactory.DEFAULT_MODEL_NAME,
@@ -197,23 +112,19 @@ public class ChatModelParameterConfig extends ParameterConfig {
} }
private static List<Parameter.Dependency> getEndpointDependency() { private static List<Parameter.Dependency> getEndpointDependency() {
return getDependency( return getDependency(CHAT_MODEL_PROVIDER.getName(),
CHAT_MODEL_PROVIDER.getName(), Lists.newArrayList(QianfanModelFactory.PROVIDER), ImmutableMap
Lists.newArrayList(QianfanModelFactory.PROVIDER), .of(QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_ENDPOINT));
ImmutableMap.of(
QianfanModelFactory.PROVIDER, QianfanModelFactory.DEFAULT_ENDPOINT));
} }
private static List<Parameter.Dependency> getEnableSearchDependency() { private static List<Parameter.Dependency> getEnableSearchDependency() {
return getDependency( return getDependency(CHAT_MODEL_PROVIDER.getName(),
CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(DashscopeModelFactory.PROVIDER), Lists.newArrayList(DashscopeModelFactory.PROVIDER),
ImmutableMap.of(DashscopeModelFactory.PROVIDER, "false")); ImmutableMap.of(DashscopeModelFactory.PROVIDER, "false"));
} }
private static List<Parameter.Dependency> getSecretKeyDependency() { private static List<Parameter.Dependency> getSecretKeyDependency() {
return getDependency( return getDependency(CHAT_MODEL_PROVIDER.getName(),
CHAT_MODEL_PROVIDER.getName(),
Lists.newArrayList(QianfanModelFactory.PROVIDER), Lists.newArrayList(QianfanModelFactory.PROVIDER),
ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO)); ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO));
} }

View File

@@ -22,89 +22,35 @@ import java.util.List;
@Slf4j @Slf4j
public class EmbeddingModelParameterConfig extends ParameterConfig { public class EmbeddingModelParameterConfig extends ParameterConfig {
public static final Parameter EMBEDDING_MODEL_PROVIDER = public static final Parameter EMBEDDING_MODEL_PROVIDER =
new Parameter( new Parameter("s2.embedding.model.provider", InMemoryModelFactory.PROVIDER, "接口协议", "",
"s2.embedding.model.provider", "list", "向量模型配置", getCandidateValues());
InMemoryModelFactory.PROVIDER,
"接口协议",
"",
"list",
"向量模型配置",
getCandidateValues());
public static final Parameter EMBEDDING_MODEL_BASE_URL = public static final Parameter EMBEDDING_MODEL_BASE_URL =
new Parameter( new Parameter("s2.embedding.model.base.url", "", "BaseUrl", "", "string", "向量模型配置",
"s2.embedding.model.base.url", null, getBaseUrlDependency());
"",
"BaseUrl",
"",
"string",
"向量模型配置",
null,
getBaseUrlDependency());
public static final Parameter EMBEDDING_MODEL_API_KEY = public static final Parameter EMBEDDING_MODEL_API_KEY =
new Parameter( new Parameter("s2.embedding.model.api.key", "", "ApiKey", "", "password", "向量模型配置",
"s2.embedding.model.api.key", null, getApiKeyDependency());
"",
"ApiKey",
"",
"password",
"向量模型配置",
null,
getApiKeyDependency());
public static final Parameter EMBEDDING_MODEL_SECRET_KEY = public static final Parameter EMBEDDING_MODEL_SECRET_KEY =
new Parameter( new Parameter("s2.embedding.model.secretKey", "demo", "SecretKey", "", "password",
"s2.embedding.model.secretKey", "向量模型配置", null, getSecretKeyDependency());
"demo",
"SecretKey",
"",
"password",
"向量模型配置",
null,
getSecretKeyDependency());
public static final Parameter EMBEDDING_MODEL_NAME = public static final Parameter EMBEDDING_MODEL_NAME =
new Parameter( new Parameter("s2.embedding.model.name", EmbeddingModelConstant.BGE_SMALL_ZH,
"s2.embedding.model.name", "ModelName", "", "string", "向量模型配置", null, getModelNameDependency());
EmbeddingModelConstant.BGE_SMALL_ZH,
"ModelName",
"",
"string",
"向量模型配置",
null,
getModelNameDependency());
public static final Parameter EMBEDDING_MODEL_PATH = public static final Parameter EMBEDDING_MODEL_PATH = new Parameter("s2.embedding.model.path",
new Parameter( "", "模型路径", "", "string", "向量模型配置", null, getModelPathDependency());
"s2.embedding.model.path",
"",
"模型路径",
"",
"string",
"向量模型配置",
null,
getModelPathDependency());
public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH = public static final Parameter EMBEDDING_MODEL_VOCABULARY_PATH =
new Parameter( new Parameter("s2.embedding.model.vocabulary.path", "", "词汇表路径", "", "string", "向量模型配置",
"s2.embedding.model.vocabulary.path", null, getModelPathDependency());
"",
"词汇表路径",
"",
"string",
"向量模型配置",
null,
getModelPathDependency());
@Override @Override
public List<Parameter> getSysParameters() { public List<Parameter> getSysParameters() {
return Lists.newArrayList( return Lists.newArrayList(EMBEDDING_MODEL_PROVIDER, EMBEDDING_MODEL_BASE_URL,
EMBEDDING_MODEL_PROVIDER, EMBEDDING_MODEL_API_KEY, EMBEDDING_MODEL_SECRET_KEY, EMBEDDING_MODEL_NAME,
EMBEDDING_MODEL_BASE_URL, EMBEDDING_MODEL_PATH, EMBEDDING_MODEL_VOCABULARY_PATH);
EMBEDDING_MODEL_API_KEY,
EMBEDDING_MODEL_SECRET_KEY,
EMBEDDING_MODEL_NAME,
EMBEDDING_MODEL_PATH,
EMBEDDING_MODEL_VOCABULARY_PATH);
} }
public EmbeddingModelConfig convert() { public EmbeddingModelConfig convert() {
@@ -115,40 +61,24 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
String modelPath = getParameterValue(EMBEDDING_MODEL_PATH); String modelPath = getParameterValue(EMBEDDING_MODEL_PATH);
String vocabularyPath = getParameterValue(EMBEDDING_MODEL_VOCABULARY_PATH); String vocabularyPath = getParameterValue(EMBEDDING_MODEL_VOCABULARY_PATH);
String secretKey = getParameterValue(EMBEDDING_MODEL_SECRET_KEY); String secretKey = getParameterValue(EMBEDDING_MODEL_SECRET_KEY);
return EmbeddingModelConfig.builder() return EmbeddingModelConfig.builder().provider(provider).baseUrl(baseUrl).apiKey(apiKey)
.provider(provider) .secretKey(secretKey).modelName(modelName).modelPath(modelPath)
.baseUrl(baseUrl) .vocabularyPath(vocabularyPath).build();
.apiKey(apiKey)
.secretKey(secretKey)
.modelName(modelName)
.modelPath(modelPath)
.vocabularyPath(vocabularyPath)
.build();
} }
private static ArrayList<String> getCandidateValues() { private static ArrayList<String> getCandidateValues() {
return Lists.newArrayList( return Lists.newArrayList(InMemoryModelFactory.PROVIDER, OpenAiModelFactory.PROVIDER,
InMemoryModelFactory.PROVIDER, OllamaModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
OpenAiModelFactory.PROVIDER, QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER,
AzureModelFactory.PROVIDER); AzureModelFactory.PROVIDER);
} }
private static List<Parameter.Dependency> getBaseUrlDependency() { private static List<Parameter.Dependency> getBaseUrlDependency() {
return getDependency( return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
EMBEDDING_MODEL_PROVIDER.getName(), Lists.newArrayList(OpenAiModelFactory.PROVIDER, OllamaModelFactory.PROVIDER,
Lists.newArrayList( AzureModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
OpenAiModelFactory.PROVIDER, QianfanModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER),
OllamaModelFactory.PROVIDER, ImmutableMap.of(OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER),
ImmutableMap.of(
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.DEFAULT_BASE_URL,
OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL, OllamaModelFactory.PROVIDER, OllamaModelFactory.DEFAULT_BASE_URL,
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL, AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_BASE_URL,
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL, DashscopeModelFactory.PROVIDER, DashscopeModelFactory.DEFAULT_BASE_URL,
@@ -157,63 +87,43 @@ public class EmbeddingModelParameterConfig extends ParameterConfig {
} }
private static List<Parameter.Dependency> getApiKeyDependency() { private static List<Parameter.Dependency> getApiKeyDependency() {
return getDependency( return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
EMBEDDING_MODEL_PROVIDER.getName(), Lists.newArrayList(OpenAiModelFactory.PROVIDER, AzureModelFactory.PROVIDER,
Lists.newArrayList( DashscopeModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
OpenAiModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER), ZhipuModelFactory.PROVIDER),
ImmutableMap.of( ImmutableMap.of(OpenAiModelFactory.PROVIDER, DEMO, AzureModelFactory.PROVIDER, DEMO,
OpenAiModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER, DEMO, QianfanModelFactory.PROVIDER, DEMO,
DEMO, ZhipuModelFactory.PROVIDER, DEMO));
AzureModelFactory.PROVIDER,
DEMO,
DashscopeModelFactory.PROVIDER,
DEMO,
QianfanModelFactory.PROVIDER,
DEMO,
ZhipuModelFactory.PROVIDER,
DEMO));
} }
private static List<Parameter.Dependency> getModelNameDependency() { private static List<Parameter.Dependency> getModelNameDependency() {
return getDependency( return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
EMBEDDING_MODEL_PROVIDER.getName(), Lists.newArrayList(InMemoryModelFactory.PROVIDER, OpenAiModelFactory.PROVIDER,
Lists.newArrayList( OllamaModelFactory.PROVIDER, AzureModelFactory.PROVIDER,
InMemoryModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER), ZhipuModelFactory.PROVIDER),
ImmutableMap.of( ImmutableMap.of(InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
InMemoryModelFactory.PROVIDER, EmbeddingModelConstant.BGE_SMALL_ZH,
OpenAiModelFactory.PROVIDER, OpenAiModelFactory.PROVIDER,
OpenAiModelFactory.DEFAULT_EMBEDDING_MODEL_NAME, OpenAiModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
OllamaModelFactory.PROVIDER, OllamaModelFactory.PROVIDER,
OllamaModelFactory.DEFAULT_EMBEDDING_MODEL_NAME, OllamaModelFactory.DEFAULT_EMBEDDING_MODEL_NAME, AzureModelFactory.PROVIDER,
AzureModelFactory.PROVIDER, AzureModelFactory.DEFAULT_EMBEDDING_MODEL_NAME, AzureModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
DashscopeModelFactory.PROVIDER, DashscopeModelFactory.PROVIDER,
DashscopeModelFactory.DEFAULT_EMBEDDING_MODEL_NAME, DashscopeModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
QianfanModelFactory.PROVIDER, QianfanModelFactory.PROVIDER,
QianfanModelFactory.DEFAULT_EMBEDDING_MODEL_NAME, QianfanModelFactory.DEFAULT_EMBEDDING_MODEL_NAME,
ZhipuModelFactory.PROVIDER, ZhipuModelFactory.PROVIDER,
ZhipuModelFactory.DEFAULT_EMBEDDING_MODEL_NAME)); ZhipuModelFactory.DEFAULT_EMBEDDING_MODEL_NAME));
} }
private static List<Parameter.Dependency> getModelPathDependency() { private static List<Parameter.Dependency> getModelPathDependency() {
return getDependency( return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(InMemoryModelFactory.PROVIDER), Lists.newArrayList(InMemoryModelFactory.PROVIDER),
ImmutableMap.of(InMemoryModelFactory.PROVIDER, "")); ImmutableMap.of(InMemoryModelFactory.PROVIDER, ""));
} }
private static List<Parameter.Dependency> getSecretKeyDependency() { private static List<Parameter.Dependency> getSecretKeyDependency() {
return getDependency( return getDependency(EMBEDDING_MODEL_PROVIDER.getName(),
EMBEDDING_MODEL_PROVIDER.getName(),
Lists.newArrayList(QianfanModelFactory.PROVIDER), Lists.newArrayList(QianfanModelFactory.PROVIDER),
ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO)); ImmutableMap.of(QianfanModelFactory.PROVIDER, DEMO));
} }

View File

@@ -15,83 +15,38 @@ import java.util.List;
@Service("EmbeddingStoreParameterConfig") @Service("EmbeddingStoreParameterConfig")
@Slf4j @Slf4j
public class EmbeddingStoreParameterConfig extends ParameterConfig { public class EmbeddingStoreParameterConfig extends ParameterConfig {
public static final Parameter EMBEDDING_STORE_PROVIDER = public static final Parameter EMBEDDING_STORE_PROVIDER = new Parameter(
new Parameter( "s2.embedding.store.provider", EmbeddingStoreType.IN_MEMORY.name(), "向量库类型",
"s2.embedding.store.provider", "目前支持三种类型IN_MEMORY、MILVUS、CHROMA", "list", "向量库配置", getCandidateValues());
EmbeddingStoreType.IN_MEMORY.name(),
"向量库类型",
"目前支持三种类型IN_MEMORY、MILVUS、CHROMA",
"list",
"向量库配置",
getCandidateValues());
public static final Parameter EMBEDDING_STORE_BASE_URL = public static final Parameter EMBEDDING_STORE_BASE_URL =
new Parameter( new Parameter("s2.embedding.store.base.url", "", "BaseUrl", "", "string", "向量库配置", null,
"s2.embedding.store.base.url",
"",
"BaseUrl",
"",
"string",
"向量库配置",
null,
getBaseUrlDependency()); getBaseUrlDependency());
public static final Parameter EMBEDDING_STORE_API_KEY = public static final Parameter EMBEDDING_STORE_API_KEY =
new Parameter( new Parameter("s2.embedding.store.api.key", "", "ApiKey", "", "password", "向量库配置", null,
"s2.embedding.store.api.key",
"",
"ApiKey",
"",
"password",
"向量库配置",
null,
getApiKeyDependency()); getApiKeyDependency());
public static final Parameter EMBEDDING_STORE_PERSIST_PATH = public static final Parameter EMBEDDING_STORE_PERSIST_PATH =
new Parameter( new Parameter("s2.embedding.store.persist.path", "", "持久化路径",
"s2.embedding.store.persist.path", "默认不持久化,如需持久化请填写持久化路径。" + "注意:如果变更了向量模型需删除该路径下已保存的文件或修改持久化路径", "string",
"", "向量库配置", null, getPathDependency());
"持久化路径",
"默认不持久化,如需持久化请填写持久化路径。" + "注意:如果变更了向量模型需删除该路径下已保存的文件或修改持久化路径",
"string",
"向量库配置",
null,
getPathDependency());
public static final Parameter EMBEDDING_STORE_TIMEOUT = public static final Parameter EMBEDDING_STORE_TIMEOUT =
new Parameter("s2.embedding.store.timeout", "60", "超时时间(秒)", "", "number", "向量库配置"); new Parameter("s2.embedding.store.timeout", "60", "超时时间(秒)", "", "number", "向量库配置");
public static final Parameter EMBEDDING_STORE_DIMENSION = public static final Parameter EMBEDDING_STORE_DIMENSION =
new Parameter( new Parameter("s2.embedding.store.dimension", "", "纬度", "", "number", "向量库配置", null,
"s2.embedding.store.dimension",
"",
"纬度",
"",
"number",
"向量库配置",
null,
getDimensionDependency()); getDimensionDependency());
public static final Parameter EMBEDDING_STORE_DATABASE_NAME = public static final Parameter EMBEDDING_STORE_DATABASE_NAME =
new Parameter( new Parameter("s2.embedding.store.databaseName", "", "DatabaseName", "", "string",
"s2.embedding.store.databaseName", "向量库配置", null, getDatabaseNameDependency());
"",
"DatabaseName",
"",
"string",
"向量库配置",
null,
getDatabaseNameDependency());
@Override @Override
public List<Parameter> getSysParameters() { public List<Parameter> getSysParameters() {
return Lists.newArrayList( return Lists.newArrayList(EMBEDDING_STORE_PROVIDER, EMBEDDING_STORE_BASE_URL,
EMBEDDING_STORE_PROVIDER, EMBEDDING_STORE_API_KEY, EMBEDDING_STORE_DATABASE_NAME,
EMBEDDING_STORE_BASE_URL, EMBEDDING_STORE_PERSIST_PATH, EMBEDDING_STORE_TIMEOUT, EMBEDDING_STORE_DIMENSION);
EMBEDDING_STORE_API_KEY,
EMBEDDING_STORE_DATABASE_NAME,
EMBEDDING_STORE_PERSIST_PATH,
EMBEDDING_STORE_TIMEOUT,
EMBEDDING_STORE_DIMENSION);
} }
public EmbeddingStoreConfig convert() { public EmbeddingStoreConfig convert() {
@@ -105,58 +60,44 @@ public class EmbeddingStoreParameterConfig extends ParameterConfig {
if (StringUtils.isNumeric(getParameterValue(EMBEDDING_STORE_DIMENSION))) { if (StringUtils.isNumeric(getParameterValue(EMBEDDING_STORE_DIMENSION))) {
dimension = Integer.valueOf(getParameterValue(EMBEDDING_STORE_DIMENSION)); dimension = Integer.valueOf(getParameterValue(EMBEDDING_STORE_DIMENSION));
} }
return EmbeddingStoreConfig.builder() return EmbeddingStoreConfig.builder().provider(provider).baseUrl(baseUrl).apiKey(apiKey)
.provider(provider) .persistPath(persistPath).databaseName(databaseName).timeOut(Long.valueOf(timeOut))
.baseUrl(baseUrl) .dimension(dimension).build();
.apiKey(apiKey)
.persistPath(persistPath)
.databaseName(databaseName)
.timeOut(Long.valueOf(timeOut))
.dimension(dimension)
.build();
} }
private static ArrayList<String> getCandidateValues() { private static ArrayList<String> getCandidateValues() {
return Lists.newArrayList( return Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name(),
EmbeddingStoreType.IN_MEMORY.name(), EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.CHROMA.name());
EmbeddingStoreType.MILVUS.name(),
EmbeddingStoreType.CHROMA.name());
} }
private static List<Parameter.Dependency> getBaseUrlDependency() { private static List<Parameter.Dependency> getBaseUrlDependency() {
return getDependency( return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
EMBEDDING_STORE_PROVIDER.getName(), Lists.newArrayList(EmbeddingStoreType.MILVUS.name(),
Lists.newArrayList( EmbeddingStoreType.CHROMA.name()),
EmbeddingStoreType.MILVUS.name(), EmbeddingStoreType.CHROMA.name()), ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "http://localhost:19530",
ImmutableMap.of(
EmbeddingStoreType.MILVUS.name(), "http://localhost:19530",
EmbeddingStoreType.CHROMA.name(), "http://localhost:8000")); EmbeddingStoreType.CHROMA.name(), "http://localhost:8000"));
} }
private static List<Parameter.Dependency> getApiKeyDependency() { private static List<Parameter.Dependency> getApiKeyDependency() {
return getDependency( return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(EmbeddingStoreType.MILVUS.name()), Lists.newArrayList(EmbeddingStoreType.MILVUS.name()),
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), DEMO)); ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), DEMO));
} }
private static List<Parameter.Dependency> getPathDependency() { private static List<Parameter.Dependency> getPathDependency() {
return getDependency( return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name()), Lists.newArrayList(EmbeddingStoreType.IN_MEMORY.name()),
ImmutableMap.of(EmbeddingStoreType.IN_MEMORY.name(), "")); ImmutableMap.of(EmbeddingStoreType.IN_MEMORY.name(), ""));
} }
private static List<Parameter.Dependency> getDimensionDependency() { private static List<Parameter.Dependency> getDimensionDependency() {
return getDependency( return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(EmbeddingStoreType.MILVUS.name()), Lists.newArrayList(EmbeddingStoreType.MILVUS.name()),
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "384")); ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "384"));
} }
private static List<Parameter.Dependency> getDatabaseNameDependency() { private static List<Parameter.Dependency> getDatabaseNameDependency() {
return getDependency( return getDependency(EMBEDDING_STORE_PROVIDER.getName(),
EMBEDDING_STORE_PROVIDER.getName(),
Lists.newArrayList(EmbeddingStoreType.MILVUS.name()), Lists.newArrayList(EmbeddingStoreType.MILVUS.name()),
ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), "")); ImmutableMap.of(EmbeddingStoreType.MILVUS.name(), ""));
} }

View File

@@ -15,9 +15,11 @@ import java.util.Map;
@Service @Service
public abstract class ParameterConfig { public abstract class ParameterConfig {
public static final String DEMO = "demo"; public static final String DEMO = "demo";
@Autowired private SystemConfigService sysConfigService; @Autowired
private SystemConfigService sysConfigService;
@Autowired private Environment environment; @Autowired
private Environment environment;
/** @return system parameters to be set with user interface */ /** @return system parameters to be set with user interface */
protected List<Parameter> getSysParameters() { protected List<Parameter> getSysParameters() {
@@ -46,10 +48,8 @@ public abstract class ParameterConfig {
return value; return value;
} }
protected static List<Parameter.Dependency> getDependency( protected static List<Parameter.Dependency> getDependency(String dependencyParameterName,
String dependencyParameterName, List<String> includesValue, Map<String, String> setDefaultValue) {
List<String> includesValue,
Map<String, String> setDefaultValue) {
Parameter.Dependency.Show show = new Parameter.Dependency.Show(); Parameter.Dependency.Show show = new Parameter.Dependency.Show();
show.setIncludesValue(includesValue); show.setIncludesValue(includesValue);

View File

@@ -38,11 +38,8 @@ public class SystemConfig {
if (StringUtils.isBlank(name)) { if (StringUtils.isBlank(name)) {
return ""; return "";
} }
Map<String, String> nameToValue = Map<String, String> nameToValue = getParameters().stream()
getParameters().stream() .collect(Collectors.toMap(Parameter::getName, Parameter::getValue, (k1, k2) -> k1));
.collect(
Collectors.toMap(
Parameter::getName, Parameter::getValue, (k1, k2) -> k1));
return nameToValue.get(name); return nameToValue.get(name);
} }
@@ -69,15 +66,11 @@ public class SystemConfig {
if (CollectionUtils.isEmpty(parameters)) { if (CollectionUtils.isEmpty(parameters)) {
return defaultParameters; return defaultParameters;
} }
Map<String, String> parameterNameValueMap = Map<String, String> parameterNameValueMap = parameters.stream()
parameters.stream() .collect(Collectors.toMap(Parameter::getName, Parameter::getValue, (v1, v2) -> v2));
.collect(
Collectors.toMap(
Parameter::getName, Parameter::getValue, (v1, v2) -> v2));
for (Parameter parameter : defaultParameters) { for (Parameter parameter : defaultParameters) {
parameter.setValue( parameter.setValue(parameterNameValueMap.getOrDefault(parameter.getName(),
parameterNameValueMap.getOrDefault( parameter.getDefaultValue()));
parameter.getName(), parameter.getDefaultValue()));
} }
return defaultParameters; return defaultParameters;
} }

View File

@@ -14,8 +14,8 @@ import org.springframework.web.servlet.ModelAndView;
@Slf4j @Slf4j
public class LogInterceptor implements HandlerInterceptor { public class LogInterceptor implements HandlerInterceptor {
@Override @Override
public boolean preHandle( public boolean preHandle(HttpServletRequest request, HttpServletResponse response,
HttpServletRequest request, HttpServletResponse response, Object handler) { Object handler) {
// use previous traceId // use previous traceId
String traceId = request.getHeader(TraceIdUtil.TRACE_ID); String traceId = request.getHeader(TraceIdUtil.TRACE_ID);
if (StringUtils.isBlank(traceId)) { if (StringUtils.isBlank(traceId)) {
@@ -27,17 +27,12 @@ public class LogInterceptor implements HandlerInterceptor {
} }
@Override @Override
public void postHandle( public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler,
HttpServletRequest request, ModelAndView modelAndView) throws Exception {}
HttpServletResponse response,
Object handler,
ModelAndView modelAndView)
throws Exception {}
@Override @Override
public void afterCompletion( public void afterCompletion(HttpServletRequest request, HttpServletResponse response,
HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) Object handler, Exception ex) throws Exception {
throws Exception {
// remove after Completing // remove after Completing
TraceIdUtil.remove(); TraceIdUtil.remove();
} }

View File

@@ -5,13 +5,9 @@ import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
public enum AggregateEnum { public enum AggregateEnum {
MOST("最多", "max"), MOST("最多", "max"), HIGHEST("最高", "max"), MAXIMUN("最大", "max"), LEAST("最少",
HIGHEST("", "max"), "min"), SMALLEST("", "min"), LOWEST("最低", "min"), AVERAGE("平均", "avg");
MAXIMUN("最大", "max"),
LEAST("最少", "min"),
SMALLEST("最小", "min"),
LOWEST("最低", "min"),
AVERAGE("平均", "avg");
private String aggregateCh; private String aggregateCh;
private String aggregateEN; private String aggregateEN;
@@ -29,9 +25,7 @@ public enum AggregateEnum {
} }
public static Map<String, String> getAggregateEnum() { public static Map<String, String> getAggregateEnum() {
return Arrays.stream(AggregateEnum.values()) return Arrays.stream(AggregateEnum.values()).collect(
.collect( Collectors.toMap(AggregateEnum::getAggregateCh, AggregateEnum::getAggregateEN));
Collectors.toMap(
AggregateEnum::getAggregateCh, AggregateEnum::getAggregateEN));
} }
} }

View File

@@ -15,8 +15,8 @@ public class CustomExpressionDeParser extends ExpressionDeParser {
private boolean dealNull; private boolean dealNull;
private boolean dealNotNull; private boolean dealNotNull;
public CustomExpressionDeParser( public CustomExpressionDeParser(Set<String> removeFieldNames, boolean dealNull,
Set<String> removeFieldNames, boolean dealNull, boolean dealNotNull) { boolean dealNotNull) {
this.removeFieldNames = removeFieldNames; this.removeFieldNames = removeFieldNames;
this.dealNull = dealNull; this.dealNull = dealNull;
this.dealNotNull = dealNotNull; this.dealNotNull = dealNotNull;
@@ -45,12 +45,10 @@ public class CustomExpressionDeParser extends ExpressionDeParser {
Expression leftExpression = ((AndExpression) binaryExpression).getLeftExpression(); Expression leftExpression = ((AndExpression) binaryExpression).getLeftExpression();
Expression rightExpression = ((AndExpression) binaryExpression).getRightExpression(); Expression rightExpression = ((AndExpression) binaryExpression).getRightExpression();
boolean leftIsNull = boolean leftIsNull = leftExpression instanceof IsNullExpression
leftExpression instanceof IsNullExpression && shouldSkip((IsNullExpression) leftExpression);
&& shouldSkip((IsNullExpression) leftExpression); boolean rightIsNull = rightExpression instanceof IsNullExpression
boolean rightIsNull = && shouldSkip((IsNullExpression) rightExpression);
rightExpression instanceof IsNullExpression
&& shouldSkip((IsNullExpression) rightExpression);
if (leftIsNull && rightIsNull) { if (leftIsNull && rightIsNull) {
// Skip both expressions // Skip both expressions

View File

@@ -13,8 +13,8 @@ import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
@Slf4j @Slf4j
public class DateFunctionHelper { public class DateFunctionHelper {
public static String getStartDateStr( public static String getStartDateStr(ComparisonOperator minorThanEquals,
ComparisonOperator minorThanEquals, ExpressionList<?> expressions) { ExpressionList<?> expressions) {
String unitValue = getUnit(expressions); String unitValue = getUnit(expressions);
String dateValue = getEndDateValue(expressions); String dateValue = getEndDateValue(expressions);
String dateStr = ""; String dateStr = "";

View File

@@ -23,9 +23,8 @@ public class ExpressionReplaceVisitor extends ExpressionVisitorAdapter {
expr.getWhenExpression().accept(this); expr.getWhenExpression().accept(this);
if (expr.getThenExpression() instanceof Column) { if (expr.getThenExpression() instanceof Column) {
Column column = (Column) expr.getThenExpression(); Column column = (Column) expr.getThenExpression();
Expression expression = Expression expression = QueryExpressionReplaceVisitor.getExpression(
QueryExpressionReplaceVisitor.getExpression( QueryExpressionReplaceVisitor.getReplaceExpr(column, fieldExprMap));
QueryExpressionReplaceVisitor.getReplaceExpr(column, fieldExprMap));
if (Objects.nonNull(expression)) { if (Objects.nonNull(expression)) {
expr.setThenExpression(expression); expr.setThenExpression(expression);
} }
@@ -52,20 +51,16 @@ public class ExpressionReplaceVisitor extends ExpressionVisitorAdapter {
} }
} }
if (left instanceof Column) { if (left instanceof Column) {
Expression expression = Expression expression = QueryExpressionReplaceVisitor.getExpression(
QueryExpressionReplaceVisitor.getExpression( QueryExpressionReplaceVisitor.getReplaceExpr((Column) left, fieldExprMap));
QueryExpressionReplaceVisitor.getReplaceExpr(
(Column) left, fieldExprMap));
if (Objects.nonNull(expression)) { if (Objects.nonNull(expression)) {
expr.setLeftExpression(expression); expr.setLeftExpression(expression);
leftVisited = true; leftVisited = true;
} }
} }
if (right instanceof Column) { if (right instanceof Column) {
Expression expression = Expression expression = QueryExpressionReplaceVisitor.getExpression(
QueryExpressionReplaceVisitor.getExpression( QueryExpressionReplaceVisitor.getReplaceExpr((Column) right, fieldExprMap));
QueryExpressionReplaceVisitor.getReplaceExpr(
(Column) right, fieldExprMap));
if (Objects.nonNull(expression)) { if (Objects.nonNull(expression)) {
expr.setRightExpression(expression); expr.setRightExpression(expression);
rightVisited = true; rightVisited = true;
@@ -81,9 +76,8 @@ public class ExpressionReplaceVisitor extends ExpressionVisitorAdapter {
private boolean visitFunction(Function function) { private boolean visitFunction(Function function) {
if (function.getParameters().getExpressions().get(0) instanceof Column) { if (function.getParameters().getExpressions().get(0) instanceof Column) {
Expression expression = Expression expression = QueryExpressionReplaceVisitor.getExpression(
QueryExpressionReplaceVisitor.getExpression( QueryExpressionReplaceVisitor.getReplaceExpr(function, fieldExprMap));
QueryExpressionReplaceVisitor.getReplaceExpr(function, fieldExprMap));
if (Objects.nonNull(expression)) { if (Objects.nonNull(expression)) {
ExpressionList<Expression> expressions = new ExpressionList<>(); ExpressionList<Expression> expressions = new ExpressionList<>();
expressions.add(expression); expressions.add(expression);

View File

@@ -130,8 +130,8 @@ public class FieldAndValueAcquireVisitor extends ExpressionVisitorAdapter {
Arrays.stream(DatePeriodEnum.values()).collect(Collectors.toList()); Arrays.stream(DatePeriodEnum.values()).collect(Collectors.toList());
DatePeriodEnum periodEnum = DatePeriodEnum.get(functionName); DatePeriodEnum periodEnum = DatePeriodEnum.get(functionName);
if (Objects.nonNull(periodEnum) && collect.contains(periodEnum)) { if (Objects.nonNull(periodEnum) && collect.contains(periodEnum)) {
fieldExpression.setFieldValue( fieldExpression
getFieldValue(rightExpression) + periodEnum.getChName()); .setFieldValue(getFieldValue(rightExpression) + periodEnum.getChName());
return fieldExpression; return fieldExpression;
} else { } else {
// deal with aggregate function // deal with aggregate function

View File

@@ -31,8 +31,8 @@ public class FieldValueReplaceVisitor extends ExpressionVisitorAdapter {
private boolean exactReplace; private boolean exactReplace;
private Map<String, Map<String, String>> filedNameToValueMap; private Map<String, Map<String, String>> filedNameToValueMap;
public FieldValueReplaceVisitor( public FieldValueReplaceVisitor(boolean exactReplace,
boolean exactReplace, Map<String, Map<String, String>> filedNameToValueMap) { Map<String, Map<String, String>> filedNameToValueMap) {
this.exactReplace = exactReplace; this.exactReplace = exactReplace;
this.filedNameToValueMap = filedNameToValueMap; this.filedNameToValueMap = filedNameToValueMap;
} }
@@ -67,24 +67,20 @@ public class FieldValueReplaceVisitor extends ExpressionVisitorAdapter {
ExpressionList rightItemsList = (ExpressionList) inExpression.getRightExpression(); ExpressionList rightItemsList = (ExpressionList) inExpression.getRightExpression();
List<Expression> expressions = rightItemsList.getExpressions(); List<Expression> expressions = rightItemsList.getExpressions();
List<String> values = new ArrayList<>(); List<String> values = new ArrayList<>();
expressions.stream() expressions.stream().forEach(o -> {
.forEach( if (o instanceof StringValue) {
o -> { values.add(((StringValue) o).getValue());
if (o instanceof StringValue) { }
values.add(((StringValue) o).getValue()); });
}
});
if (valueMap == null || CollectionUtils.isEmpty(values)) { if (valueMap == null || CollectionUtils.isEmpty(values)) {
return; return;
} }
List<Expression> newExpressions = new ArrayList<>(); List<Expression> newExpressions = new ArrayList<>();
values.stream() values.stream().forEach(o -> {
.forEach( String replaceValue = valueMap.getOrDefault(o, o);
o -> { StringValue stringValue = new StringValue(replaceValue);
String replaceValue = valueMap.getOrDefault(o, o); newExpressions.add(stringValue);
StringValue stringValue = new StringValue(replaceValue); });
newExpressions.add(stringValue);
});
rightItemsList.setExpressions(newExpressions); rightItemsList.setExpressions(newExpressions);
inExpression.setRightExpression(rightItemsList); inExpression.setRightExpression(rightItemsList);
} }

View File

@@ -34,11 +34,9 @@ public class FiledNameReplaceVisitor extends ExpressionVisitorAdapter {
Expression leftExpression = expr.getLeftExpression(); Expression leftExpression = expr.getLeftExpression();
Expression rightExpression = expr.getRightExpression(); Expression rightExpression = expr.getRightExpression();
if (!(rightExpression instanceof StringValue) if (!(rightExpression instanceof StringValue) || !(leftExpression instanceof Column)
|| !(leftExpression instanceof Column)
|| CollectionUtils.isEmpty(fieldValueToFieldNames) || CollectionUtils.isEmpty(fieldValueToFieldNames)
|| Objects.isNull(rightExpression) || Objects.isNull(rightExpression) || Objects.isNull(leftExpression)) {
|| Objects.isNull(leftExpression)) {
return; return;
} }

View File

@@ -21,8 +21,8 @@ public class FunctionAliasReplaceVisitor extends SelectItemVisitorAdapter {
// 2.alias's fieldName not equal. "sum(pv) as pv" cannot be replaced. // 2.alias's fieldName not equal. "sum(pv) as pv" cannot be replaced.
if (Objects.nonNull(selectExpressionItem.getAlias()) if (Objects.nonNull(selectExpressionItem.getAlias())
&& !selectExpressionItem.getAlias().getName().equalsIgnoreCase(columnName)) { && !selectExpressionItem.getAlias().getName().equalsIgnoreCase(columnName)) {
aliasToActualExpression.put( aliasToActualExpression.put(selectExpressionItem.getAlias().getName(),
selectExpressionItem.getAlias().getName(), function.toString()); function.toString());
selectExpressionItem.setAlias(null); selectExpressionItem.setAlias(null);
} }
} }

View File

@@ -16,8 +16,8 @@ public class FunctionNameReplaceVisitor extends ExpressionVisitorAdapter {
private Map<String, String> functionMap; private Map<String, String> functionMap;
private Map<String, UnaryOperator> functionCallMap; private Map<String, UnaryOperator> functionCallMap;
public FunctionNameReplaceVisitor( public FunctionNameReplaceVisitor(Map<String, String> functionMap,
Map<String, String> functionMap, Map<String, UnaryOperator> functionCallMap) { Map<String, UnaryOperator> functionCallMap) {
this.functionMap = functionMap; this.functionMap = functionMap;
this.functionCallMap = functionCallMap; this.functionCallMap = functionCallMap;
} }

View File

@@ -19,8 +19,8 @@ public class GroupByFunctionReplaceVisitor implements GroupByVisitor {
private Map<String, String> functionMap; private Map<String, String> functionMap;
private Map<String, UnaryOperator> functionCallMap; private Map<String, UnaryOperator> functionCallMap;
public GroupByFunctionReplaceVisitor( public GroupByFunctionReplaceVisitor(Map<String, String> functionMap,
Map<String, String> functionMap, Map<String, UnaryOperator> functionCallMap) { Map<String, UnaryOperator> functionCallMap) {
this.functionMap = functionMap; this.functionMap = functionMap;
this.functionCallMap = functionCallMap; this.functionCallMap = functionCallMap;
} }

View File

@@ -53,11 +53,8 @@ public class GroupByReplaceVisitor implements GroupByVisitor {
return expression.toString(); return expression.toString();
} }
private void replaceExpression( private void replaceExpression(List<Expression> groupByExpressions, int index,
List<Expression> groupByExpressions, Expression expression, String replaceColumn) {
int index,
Expression expression,
String replaceColumn) {
if (expression instanceof Column) { if (expression instanceof Column) {
groupByExpressions.set(index, new Column(replaceColumn)); groupByExpressions.set(index, new Column(replaceColumn));
} else if (expression instanceof Function) { } else if (expression instanceof Function) {
@@ -68,8 +65,7 @@ public class GroupByReplaceVisitor implements GroupByVisitor {
Function function = (Function) expression; Function function = (Function) expression;
if (function.getParameters().size() > 1) { if (function.getParameters().size() > 1) {
function.getParameters().stream() function.getParameters().stream().skip(1)
.skip(1)
.forEach(e -> newExpressionList.add((Function) e)); .forEach(e -> newExpressionList.add((Function) e));
} }
function.setParameters(newExpressionList); function.setParameters(newExpressionList);

View File

@@ -27,26 +27,14 @@ public class JsqlConstants {
public static final String IN_CONSTANT = " 1 in (1) "; public static final String IN_CONSTANT = " 1 in (1) ";
public static final String LIKE_CONSTANT = "1 like 1"; public static final String LIKE_CONSTANT = "1 like 1";
public static final String IN = "IN"; public static final String IN = "IN";
public static final Map<String, String> rightMap = public static final Map<String, String> rightMap = Stream.of(
Stream.of( new AbstractMap.SimpleEntry<>("<=", "<="), new AbstractMap.SimpleEntry<>("<", "<"),
new AbstractMap.SimpleEntry<>("<=", "<="), new AbstractMap.SimpleEntry<>(">=", "<="), new AbstractMap.SimpleEntry<>(">", "<"),
new AbstractMap.SimpleEntry<>("<", "<"), new AbstractMap.SimpleEntry<>("=", "<="))
new AbstractMap.SimpleEntry<>(">=", "<="), .collect(toMap(AbstractMap.SimpleEntry::getKey, AbstractMap.SimpleEntry::getValue));
new AbstractMap.SimpleEntry<>(">", "<"), public static final Map<String, String> leftMap = Stream.of(
new AbstractMap.SimpleEntry<>("=", "<=")) new AbstractMap.SimpleEntry<>("<=", ">="), new AbstractMap.SimpleEntry<>("<", ">"),
.collect( new AbstractMap.SimpleEntry<>(">=", "<="), new AbstractMap.SimpleEntry<>(">", "<"),
toMap( new AbstractMap.SimpleEntry<>("=", ">="))
AbstractMap.SimpleEntry::getKey, .collect(toMap(AbstractMap.SimpleEntry::getKey, AbstractMap.SimpleEntry::getValue));
AbstractMap.SimpleEntry::getValue));
public static final Map<String, String> leftMap =
Stream.of(
new AbstractMap.SimpleEntry<>("<=", ">="),
new AbstractMap.SimpleEntry<>("<", ">"),
new AbstractMap.SimpleEntry<>(">=", "<="),
new AbstractMap.SimpleEntry<>(">", "<"),
new AbstractMap.SimpleEntry<>("=", ">="))
.collect(
toMap(
AbstractMap.SimpleEntry::getKey,
AbstractMap.SimpleEntry::getValue));
} }

View File

@@ -13,8 +13,8 @@ import java.util.stream.Collectors;
@Slf4j @Slf4j
public class ParseVisitorHelper { public class ParseVisitorHelper {
public void replaceColumn( public void replaceColumn(Column column, Map<String, String> fieldNameMap,
Column column, Map<String, String> fieldNameMap, boolean exactReplace) { boolean exactReplace) {
String columnName = StringUtil.replaceBackticks(column.getColumnName()); String columnName = StringUtil.replaceBackticks(column.getColumnName());
String replaceColumn = getReplaceValue(columnName, fieldNameMap, exactReplace); String replaceColumn = getReplaceValue(columnName, fieldNameMap, exactReplace);
if (StringUtils.isNotBlank(replaceColumn)) { if (StringUtils.isNotBlank(replaceColumn)) {
@@ -22,8 +22,8 @@ public class ParseVisitorHelper {
} }
} }
public String getReplaceValue( public String getReplaceValue(String beforeValue, Map<String, String> valueMap,
String beforeValue, Map<String, String> valueMap, boolean exactReplace) { boolean exactReplace) {
String value = valueMap.get(beforeValue); String value = valueMap.get(beforeValue);
if (StringUtils.isNotBlank(value)) { if (StringUtils.isNotBlank(value)) {
return value; return value;
@@ -31,19 +31,13 @@ public class ParseVisitorHelper {
if (exactReplace) { if (exactReplace) {
return null; return null;
} }
Optional<Entry<String, String>> first = Optional<Entry<String, String>> first = valueMap.entrySet().stream().sorted((k1, k2) -> {
valueMap.entrySet().stream() String k1Value = k1.getKey();
.sorted( String k2Value = k2.getKey();
(k1, k2) -> { Double k1Similarity = getSimilarity(beforeValue, k1Value);
String k1Value = k1.getKey(); Double k2Similarity = getSimilarity(beforeValue, k2Value);
String k2Value = k2.getKey(); return k2Similarity.compareTo(k1Similarity);
Double k1Similarity = getSimilarity(beforeValue, k1Value); }).collect(Collectors.toList()).stream().findFirst();
Double k2Similarity = getSimilarity(beforeValue, k2Value);
return k2Similarity.compareTo(k1Similarity);
})
.collect(Collectors.toList())
.stream()
.findFirst();
if (first.isPresent()) { if (first.isPresent()) {
return first.get().getValue(); return first.get().getValue();
@@ -68,16 +62,12 @@ public class ParseVisitorHelper {
char cj = word2.charAt(j - 1); char cj = word2.charAt(j - 1);
if (ci == cj) { if (ci == cj) {
dp[i][j] = dp[i - 1][j - 1]; dp[i][j] = dp[i - 1][j - 1];
} else if (i > 1 } else if (i > 1 && j > 1 && ci == word2.charAt(j - 2)
&& j > 1
&& ci == word2.charAt(j - 2)
&& cj == word1.charAt(i - 2)) { && cj == word1.charAt(i - 2)) {
dp[i][j] = 1 + Math.min(dp[i - 2][j - 2], Math.min(dp[i][j - 1], dp[i - 1][j])); dp[i][j] = 1 + Math.min(dp[i - 2][j - 2], Math.min(dp[i][j - 1], dp[i - 1][j]));
} else { } else {
dp[i][j] = dp[i][j] = Math.min(dp[i - 1][j - 1] + 1,
Math.min( Math.min(dp[i][j - 1] + 1, dp[i - 1][j] + 1));
dp[i - 1][j - 1] + 1,
Math.min(dp[i][j - 1] + 1, dp[i - 1][j] + 1));
} }
} }
} }

View File

@@ -43,32 +43,21 @@ public class SqlAddHelper {
} }
if (selectStatement instanceof PlainSelect) { if (selectStatement instanceof PlainSelect) {
PlainSelect plainSelect = (PlainSelect) selectStatement; PlainSelect plainSelect = (PlainSelect) selectStatement;
fields.stream() fields.stream().filter(Objects::nonNull).forEach(field -> {
.filter(Objects::nonNull) SelectItem<Column> selectExpressionItem = new SelectItem(new Column(field));
.forEach( plainSelect.addSelectItems(selectExpressionItem);
field -> { });
SelectItem<Column> selectExpressionItem =
new SelectItem(new Column(field));
plainSelect.addSelectItems(selectExpressionItem);
});
} else if (selectStatement instanceof SetOperationList) { } else if (selectStatement instanceof SetOperationList) {
SetOperationList setOperationList = (SetOperationList) selectStatement; SetOperationList setOperationList = (SetOperationList) selectStatement;
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) { if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
setOperationList setOperationList.getSelects().forEach(subSelectBody -> {
.getSelects() PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
.forEach( fields.stream().forEach(field -> {
subSelectBody -> { SelectItem<Column> selectExpressionItem = new SelectItem(new Column(field));
PlainSelect subPlainSelect = (PlainSelect) subSelectBody; subPlainSelect.addSelectItems(selectExpressionItem);
fields.stream() });
.forEach( });
field -> {
SelectItem<Column> selectExpressionItem =
new SelectItem(new Column(field));
subPlainSelect.addSelectItems(
selectExpressionItem);
});
});
} }
} }
return selectStatement.toString(); return selectStatement.toString();
@@ -88,13 +77,10 @@ public class SqlAddHelper {
SetOperationList setOperationList = SetOperationList setOperationList =
(SetOperationList) selectStatement.getSetOperationList(); (SetOperationList) selectStatement.getSetOperationList();
if (!CollectionUtils.isEmpty(setOperationList.getSelects())) { if (!CollectionUtils.isEmpty(setOperationList.getSelects())) {
setOperationList setOperationList.getSelects().forEach(subSelectBody -> {
.getSelects() PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
.forEach( plainSelectList.add(subPlainSelect);
subSelectBody -> { });
PlainSelect subPlainSelect = (PlainSelect) subSelectBody;
plainSelectList.add(subPlainSelect);
});
} }
} }
@@ -238,18 +224,15 @@ public class SqlAddHelper {
if (!(selectStatement instanceof PlainSelect)) { if (!(selectStatement instanceof PlainSelect)) {
return sql; return sql;
} }
selectStatement.accept( selectStatement.accept(new SelectVisitorAdapter() {
new SelectVisitorAdapter() { @Override
@Override public void visit(PlainSelect plainSelect) {
public void visit(PlainSelect plainSelect) { addAggregateToSelectItems(plainSelect.getSelectItems(), fieldNameToAggregate);
addAggregateToSelectItems( addAggregateToOrderByItems(plainSelect.getOrderByElements(), fieldNameToAggregate);
plainSelect.getSelectItems(), fieldNameToAggregate); addAggregateToGroupByItems(plainSelect.getGroupBy(), fieldNameToAggregate);
addAggregateToOrderByItems( addAggregateToWhereItems(plainSelect.getWhere(), fieldNameToAggregate);
plainSelect.getOrderByElements(), fieldNameToAggregate); }
addAggregateToGroupByItems(plainSelect.getGroupBy(), fieldNameToAggregate); });
addAggregateToWhereItems(plainSelect.getWhere(), fieldNameToAggregate);
}
});
return selectStatement.toString(); return selectStatement.toString();
} }
@@ -276,8 +259,8 @@ public class SqlAddHelper {
return selectStatement.toString(); return selectStatement.toString();
} }
private static void addAggregateToSelectItems( private static void addAggregateToSelectItems(List<SelectItem<?>> selectItems,
List<SelectItem<?>> selectItems, Map<String, String> fieldNameToAggregate) { Map<String, String> fieldNameToAggregate) {
for (SelectItem selectItem : selectItems) { for (SelectItem selectItem : selectItems) {
Expression expression = selectItem.getExpression(); Expression expression = selectItem.getExpression();
Function function = Function function =
@@ -289,8 +272,8 @@ public class SqlAddHelper {
} }
} }
private static void addAggregateToOrderByItems( private static void addAggregateToOrderByItems(List<OrderByElement> orderByElements,
List<OrderByElement> orderByElements, Map<String, String> fieldNameToAggregate) { Map<String, String> fieldNameToAggregate) {
if (orderByElements == null) { if (orderByElements == null) {
return; return;
} }
@@ -305,8 +288,8 @@ public class SqlAddHelper {
} }
} }
private static void addAggregateToGroupByItems( private static void addAggregateToGroupByItems(GroupByElement groupByElement,
GroupByElement groupByElement, Map<String, String> fieldNameToAggregate) { Map<String, String> fieldNameToAggregate) {
if (groupByElement == null) { if (groupByElement == null) {
return; return;
} }
@@ -321,16 +304,16 @@ public class SqlAddHelper {
} }
} }
private static void addAggregateToWhereItems( private static void addAggregateToWhereItems(Expression whereExpression,
Expression whereExpression, Map<String, String> fieldNameToAggregate) { Map<String, String> fieldNameToAggregate) {
if (whereExpression == null) { if (whereExpression == null) {
return; return;
} }
modifyWhereExpression(whereExpression, fieldNameToAggregate); modifyWhereExpression(whereExpression, fieldNameToAggregate);
} }
private static void modifyWhereExpression( private static void modifyWhereExpression(Expression whereExpression,
Expression whereExpression, Map<String, String> fieldNameToAggregate) { Map<String, String> fieldNameToAggregate) {
if (SqlSelectHelper.isLogicExpression(whereExpression)) { if (SqlSelectHelper.isLogicExpression(whereExpression)) {
if (whereExpression instanceof AndExpression) { if (whereExpression instanceof AndExpression) {
AndExpression andExpression = (AndExpression) whereExpression; AndExpression andExpression = (AndExpression) whereExpression;
@@ -347,15 +330,15 @@ public class SqlAddHelper {
modifyWhereExpression(rightExpression, fieldNameToAggregate); modifyWhereExpression(rightExpression, fieldNameToAggregate);
} }
} else if (whereExpression instanceof Parenthesis) { } else if (whereExpression instanceof Parenthesis) {
modifyWhereExpression( modifyWhereExpression(((Parenthesis) whereExpression).getExpression(),
((Parenthesis) whereExpression).getExpression(), fieldNameToAggregate); fieldNameToAggregate);
} else { } else {
setAggToFunction(whereExpression, fieldNameToAggregate); setAggToFunction(whereExpression, fieldNameToAggregate);
} }
} }
private static void setAggToFunction( private static void setAggToFunction(Expression expression,
Expression expression, Map<String, String> fieldNameToAggregate) { Map<String, String> fieldNameToAggregate) {
if (!(expression instanceof ComparisonOperator)) { if (!(expression instanceof ComparisonOperator)) {
return; return;
} }
@@ -363,20 +346,16 @@ public class SqlAddHelper {
if (comparisonOperator.getRightExpression() instanceof Column) { if (comparisonOperator.getRightExpression() instanceof Column) {
String columnName = String columnName =
((Column) (comparisonOperator).getRightExpression()).getColumnName(); ((Column) (comparisonOperator).getRightExpression()).getColumnName();
Function function = Function function = SqlSelectFunctionHelper.getFunction(
SqlSelectFunctionHelper.getFunction( comparisonOperator.getRightExpression(), fieldNameToAggregate.get(columnName));
comparisonOperator.getRightExpression(),
fieldNameToAggregate.get(columnName));
if (Objects.nonNull(function)) { if (Objects.nonNull(function)) {
comparisonOperator.setRightExpression(function); comparisonOperator.setRightExpression(function);
} }
} }
if (comparisonOperator.getLeftExpression() instanceof Column) { if (comparisonOperator.getLeftExpression() instanceof Column) {
String columnName = ((Column) (comparisonOperator).getLeftExpression()).getColumnName(); String columnName = ((Column) (comparisonOperator).getLeftExpression()).getColumnName();
Function function = Function function = SqlSelectFunctionHelper.getFunction(
SqlSelectFunctionHelper.getFunction( comparisonOperator.getLeftExpression(), fieldNameToAggregate.get(columnName));
comparisonOperator.getLeftExpression(),
fieldNameToAggregate.get(columnName));
if (Objects.nonNull(function)) { if (Objects.nonNull(function)) {
comparisonOperator.setLeftExpression(function); comparisonOperator.setLeftExpression(function);
} }

View File

@@ -27,18 +27,17 @@ public class SqlAsHelper {
if (plainSelect instanceof Select) { if (plainSelect instanceof Select) {
Select select = plainSelect; Select select = plainSelect;
Select selectBody = select.getSelectBody(); Select selectBody = select.getSelectBody();
selectBody.accept( selectBody.accept(new SelectVisitorAdapter() {
new SelectVisitorAdapter() { @Override
@Override public void visit(PlainSelect plainSelect) {
public void visit(PlainSelect plainSelect) { extractAliasesFromSelect(plainSelect, aliases);
extractAliasesFromSelect(plainSelect, aliases); }
}
@Override @Override
public void visit(WithItem withItem) { public void visit(WithItem withItem) {
withItem.getSelectBody().accept(this); withItem.getSelectBody().accept(this);
} }
}); });
} }
} }
return new ArrayList<>(aliases); return new ArrayList<>(aliases);

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.common.jsqlparser; package com.tencent.supersonic.common.jsqlparser;
public enum SqlEditEnum { public enum SqlEditEnum {
NUMBER_FILTER, NUMBER_FILTER, DATEDIFF
DATEDIFF
} }

View File

@@ -67,15 +67,14 @@ public class SqlRemoveHelper {
} }
List<SelectItem<?>> selectItems = ((PlainSelect) selectStatement).getSelectItems(); List<SelectItem<?>> selectItems = ((PlainSelect) selectStatement).getSelectItems();
Set<String> fields = new HashSet<>(); Set<String> fields = new HashSet<>();
selectItems.removeIf( selectItems.removeIf(selectItem -> {
selectItem -> { String field = selectItem.getExpression().toString();
String field = selectItem.getExpression().toString(); if (fields.contains(field)) {
if (fields.contains(field)) { return true;
return true; }
} fields.add(field);
fields.add(field); return false;
return false; });
});
((PlainSelect) selectStatement).setSelectItems(selectItems); ((PlainSelect) selectStatement).setSelectItems(selectItems);
return selectStatement.toString(); return selectStatement.toString();
} }
@@ -85,18 +84,17 @@ public class SqlRemoveHelper {
if (!(selectStatement instanceof PlainSelect)) { if (!(selectStatement instanceof PlainSelect)) {
return sql; return sql;
} }
selectStatement.accept( selectStatement.accept(new SelectVisitorAdapter() {
new SelectVisitorAdapter() { @Override
@Override public void visit(PlainSelect plainSelect) {
public void visit(PlainSelect plainSelect) { removeWhereCondition(plainSelect.getWhere(), removeFieldNames);
removeWhereCondition(plainSelect.getWhere(), removeFieldNames); }
} });
});
return removeNumberFilter(selectStatement.toString()); return removeNumberFilter(selectStatement.toString());
} }
private static void removeWhereCondition( private static void removeWhereCondition(Expression whereExpression,
Expression whereExpression, Set<String> removeFieldNames) { Set<String> removeFieldNames) {
if (whereExpression == null) { if (whereExpression == null) {
return; return;
} }
@@ -121,8 +119,8 @@ public class SqlRemoveHelper {
return selectStatement.toString(); return selectStatement.toString();
} }
private static void removeWhereExpression( private static void removeWhereExpression(Expression whereExpression,
Expression whereExpression, Set<String> removeFieldNames) { Set<String> removeFieldNames) {
if (SqlSelectHelper.isLogicExpression(whereExpression)) { if (SqlSelectHelper.isLogicExpression(whereExpression)) {
BinaryExpression binaryExpression = (BinaryExpression) whereExpression; BinaryExpression binaryExpression = (BinaryExpression) whereExpression;
Expression leftExpression = binaryExpression.getLeftExpression(); Expression leftExpression = binaryExpression.getLeftExpression();
@@ -131,8 +129,8 @@ public class SqlRemoveHelper {
removeWhereExpression(leftExpression, removeFieldNames); removeWhereExpression(leftExpression, removeFieldNames);
removeWhereExpression(rightExpression, removeFieldNames); removeWhereExpression(rightExpression, removeFieldNames);
} else if (whereExpression instanceof Parenthesis) { } else if (whereExpression instanceof Parenthesis) {
removeWhereExpression( removeWhereExpression(((Parenthesis) whereExpression).getExpression(),
((Parenthesis) whereExpression).getExpression(), removeFieldNames); removeFieldNames);
} else { } else {
removeExpressionWithConstant(whereExpression, removeFieldNames); removeExpressionWithConstant(whereExpression, removeFieldNames);
} }
@@ -152,8 +150,8 @@ public class SqlRemoveHelper {
return constant; return constant;
} }
private static void removeExpressionWithConstant( private static void removeExpressionWithConstant(Expression expression,
Expression expression, Set<String> removeFieldNames) { Set<String> removeFieldNames) {
try { try {
if (expression instanceof ComparisonOperator) { if (expression instanceof ComparisonOperator) {
handleComparisonOperator((ComparisonOperator) expression, removeFieldNames); handleComparisonOperator((ComparisonOperator) expression, removeFieldNames);
@@ -167,13 +165,10 @@ public class SqlRemoveHelper {
} }
} }
private static void handleComparisonOperator( private static void handleComparisonOperator(ComparisonOperator comparisonOperator,
ComparisonOperator comparisonOperator, Set<String> removeFieldNames) Set<String> removeFieldNames) throws JSQLParserException {
throws JSQLParserException { String columnName = SqlSelectHelper.getColumnName(comparisonOperator.getLeftExpression(),
String columnName = comparisonOperator.getRightExpression());
SqlSelectHelper.getColumnName(
comparisonOperator.getLeftExpression(),
comparisonOperator.getRightExpression());
if (!removeFieldNames.contains(columnName)) { if (!removeFieldNames.contains(columnName)) {
return; return;
} }
@@ -185,9 +180,8 @@ public class SqlRemoveHelper {
private static void handleInExpression(InExpression inExpression, Set<String> removeFieldNames) private static void handleInExpression(InExpression inExpression, Set<String> removeFieldNames)
throws JSQLParserException { throws JSQLParserException {
String columnName = String columnName = SqlSelectHelper.getColumnName(inExpression.getLeftExpression(),
SqlSelectHelper.getColumnName( inExpression.getRightExpression());
inExpression.getLeftExpression(), inExpression.getRightExpression());
if (!removeFieldNames.contains(columnName)) { if (!removeFieldNames.contains(columnName)) {
return; return;
} }
@@ -196,12 +190,10 @@ public class SqlRemoveHelper {
updateInExpression(inExpression, constantExpression); updateInExpression(inExpression, constantExpression);
} }
private static void handleLikeExpression( private static void handleLikeExpression(LikeExpression likeExpression,
LikeExpression likeExpression, Set<String> removeFieldNames) Set<String> removeFieldNames) throws JSQLParserException {
throws JSQLParserException { String columnName = SqlSelectHelper.getColumnName(likeExpression.getLeftExpression(),
String columnName = likeExpression.getRightExpression());
SqlSelectHelper.getColumnName(
likeExpression.getLeftExpression(), likeExpression.getRightExpression());
if (!removeFieldNames.contains(columnName)) { if (!removeFieldNames.contains(columnName)) {
return; return;
} }
@@ -210,8 +202,8 @@ public class SqlRemoveHelper {
updateLikeExpression(likeExpression, constantExpression); updateLikeExpression(likeExpression, constantExpression);
} }
private static void updateComparisonOperator( private static void updateComparisonOperator(ComparisonOperator original,
ComparisonOperator original, ComparisonOperator constantExpression) { ComparisonOperator constantExpression) {
original.setLeftExpression(constantExpression.getLeftExpression()); original.setLeftExpression(constantExpression.getLeftExpression());
original.setRightExpression(constantExpression.getRightExpression()); original.setRightExpression(constantExpression.getRightExpression());
original.setASTNode(constantExpression.getASTNode()); original.setASTNode(constantExpression.getASTNode());
@@ -223,8 +215,8 @@ public class SqlRemoveHelper {
original.setASTNode(constantExpression.getASTNode()); original.setASTNode(constantExpression.getASTNode());
} }
private static void updateLikeExpression( private static void updateLikeExpression(LikeExpression original,
LikeExpression original, LikeExpression constantExpression) { LikeExpression constantExpression) {
original.setLeftExpression(constantExpression.getLeftExpression()); original.setLeftExpression(constantExpression.getLeftExpression());
original.setRightExpression(constantExpression.getRightExpression()); original.setRightExpression(constantExpression.getRightExpression());
} }
@@ -234,13 +226,12 @@ public class SqlRemoveHelper {
if (!(selectStatement instanceof PlainSelect)) { if (!(selectStatement instanceof PlainSelect)) {
return sql; return sql;
} }
selectStatement.accept( selectStatement.accept(new SelectVisitorAdapter() {
new SelectVisitorAdapter() { @Override
@Override public void visit(PlainSelect plainSelect) {
public void visit(PlainSelect plainSelect) { removeWhereCondition(plainSelect.getHaving(), removeFieldNames);
removeWhereCondition(plainSelect.getHaving(), removeFieldNames); }
} });
});
return removeNumberFilter(selectStatement.toString()); return removeNumberFilter(selectStatement.toString());
} }
@@ -254,16 +245,13 @@ public class SqlRemoveHelper {
return sql; return sql;
} }
ExpressionList groupByExpressionList = groupByElement.getGroupByExpressionList(); ExpressionList groupByExpressionList = groupByElement.getGroupByExpressionList();
groupByExpressionList groupByExpressionList.getExpressions().removeIf(expression -> {
.getExpressions() if (expression instanceof Column) {
.removeIf( Column column = (Column) expression;
expression -> { return fields.contains(column.getColumnName());
if (expression instanceof Column) { }
Column column = (Column) expression; return false;
return fields.contains(column.getColumnName()); });
}
return false;
});
if (CollectionUtils.isEmpty(groupByExpressionList.getExpressions())) { if (CollectionUtils.isEmpty(groupByExpressionList.getExpressions())) {
((PlainSelect) selectStatement).setGroupByElement(null); ((PlainSelect) selectStatement).setGroupByElement(null);
} }
@@ -279,15 +267,14 @@ public class SqlRemoveHelper {
Iterator<SelectItem<?>> iterator = selectItems.iterator(); Iterator<SelectItem<?>> iterator = selectItems.iterator();
while (iterator.hasNext()) { while (iterator.hasNext()) {
SelectItem selectItem = iterator.next(); SelectItem selectItem = iterator.next();
selectItem.accept( selectItem.accept(new SelectItemVisitorAdapter() {
new SelectItemVisitorAdapter() { @Override
@Override public void visit(SelectItem item) {
public void visit(SelectItem item) { if (fields.contains(item.getExpression().toString())) {
if (fields.contains(item.getExpression().toString())) { iterator.remove();
iterator.remove(); }
} }
} });
});
} }
if (selectItems.isEmpty()) { if (selectItems.isEmpty()) {
selectItems.add(new SelectItem(new AllColumns())); selectItems.add(new SelectItem(new AllColumns()));
@@ -345,17 +332,14 @@ public class SqlRemoveHelper {
} }
} }
private static Expression dealComparisonOperatorFilter( private static Expression dealComparisonOperatorFilter(Expression expression,
Expression expression, SqlEditEnum sqlEditEnum) { SqlEditEnum sqlEditEnum) {
if (Objects.isNull(expression)) { if (Objects.isNull(expression)) {
return null; return null;
} }
if (expression instanceof GreaterThanEquals if (expression instanceof GreaterThanEquals || expression instanceof GreaterThan
|| expression instanceof GreaterThan || expression instanceof MinorThan || expression instanceof MinorThanEquals
|| expression instanceof MinorThan || expression instanceof EqualsTo || expression instanceof NotEqualsTo) {
|| expression instanceof MinorThanEquals
|| expression instanceof EqualsTo
|| expression instanceof NotEqualsTo) {
return removeSingleFilter((ComparisonOperator) expression, sqlEditEnum); return removeSingleFilter((ComparisonOperator) expression, sqlEditEnum);
} else if (expression instanceof InExpression) { } else if (expression instanceof InExpression) {
InExpression inExpression = (InExpression) expression; InExpression inExpression = (InExpression) expression;
@@ -369,14 +353,14 @@ public class SqlRemoveHelper {
return expression; return expression;
} }
private static Expression removeSingleFilter( private static Expression removeSingleFilter(ComparisonOperator comparisonExpression,
ComparisonOperator comparisonExpression, SqlEditEnum sqlEditEnum) { SqlEditEnum sqlEditEnum) {
Expression leftExpression = comparisonExpression.getLeftExpression(); Expression leftExpression = comparisonExpression.getLeftExpression();
return recursionBase(leftExpression, comparisonExpression, sqlEditEnum); return recursionBase(leftExpression, comparisonExpression, sqlEditEnum);
} }
private static Expression recursionBase( private static Expression recursionBase(Expression leftExpression, Expression expression,
Expression leftExpression, Expression expression, SqlEditEnum sqlEditEnum) { SqlEditEnum sqlEditEnum) {
if (sqlEditEnum.equals(SqlEditEnum.NUMBER_FILTER)) { if (sqlEditEnum.equals(SqlEditEnum.NUMBER_FILTER)) {
return distinguishNumberFilter(leftExpression, expression); return distinguishNumberFilter(leftExpression, expression);
} }
@@ -386,8 +370,8 @@ public class SqlRemoveHelper {
return expression; return expression;
} }
private static Expression distinguishNumberFilter( private static Expression distinguishNumberFilter(Expression leftExpression,
Expression leftExpression, Expression expression) { Expression expression) {
if (leftExpression instanceof LongValue) { if (leftExpression instanceof LongValue) {
return null; return null;
} else { } else {
@@ -403,8 +387,8 @@ public class SqlRemoveHelper {
return removeIsNullOrNotNullInWhere(false, true, sql, removeFieldNames); return removeIsNullOrNotNullInWhere(false, true, sql, removeFieldNames);
} }
public static String removeIsNullOrNotNullInWhere( public static String removeIsNullOrNotNullInWhere(boolean dealNull, boolean dealNotNull,
boolean dealNull, boolean dealNotNull, String sql, Set<String> removeFieldNames) { String sql, Set<String> removeFieldNames) {
Select selectStatement = SqlSelectHelper.getSelect(sql); Select selectStatement = SqlSelectHelper.getSelect(sql);
if (!(selectStatement instanceof PlainSelect)) { if (!(selectStatement instanceof PlainSelect)) {
return sql; return sql;

Some files were not shown because too many files have changed in this diff Show More