diff --git a/assembly/bin/supersonic-build.sh b/assembly/bin/supersonic-build.sh old mode 100644 new mode 100755 diff --git a/assembly/bin/supersonic-daemon.sh b/assembly/bin/supersonic-daemon.sh old mode 100644 new mode 100755 diff --git a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/constant/UserConstants.java b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/constant/UserConstants.java index 116629f19..3f582d2c5 100644 --- a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/constant/UserConstants.java +++ b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/constant/UserConstants.java @@ -12,6 +12,8 @@ public class UserConstants { public static final String TOKEN_USER_EMAIL = "token_user_email"; + public static final String TOKEN_IS_ADMIN = "token_is_admin"; + public static final String TOKEN_ALGORITHM = "HS512"; public static final String TOKEN_CREATE_TIME = "token_create_time"; diff --git a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/User.java b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/User.java index 28241eb14..4cf2b526d 100644 --- a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/User.java +++ b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/User.java @@ -18,17 +18,22 @@ public class User { private String email; - public static User get(Long id, String name, String displayName, String email) { - return new User(id, name, displayName, email); + private Integer isAdmin; + + public static User get(Long id, String name, String displayName, String email, Integer isAdmin) { + return new User(id, name, displayName, email, isAdmin); } public static User getFakeUser() { - return new User(1L, "admin", "admin", "admin@email"); + return new User(1L, "admin", "admin", "admin@email", 1); } public String getDisplayName() { return StringUtils.isBlank(displayName) ? name : displayName; } + public boolean isSuperAdmin() { + return isAdmin != null && isAdmin == 1; + } } diff --git a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/UserWithPassword.java b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/UserWithPassword.java index c7384c1e5..36f77eae2 100644 --- a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/UserWithPassword.java +++ b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authentication/pojo/UserWithPassword.java @@ -9,13 +9,14 @@ public class UserWithPassword extends User { private String password; - public UserWithPassword(Long id, String name, String displayName, String email, String password) { - super(id, name, displayName, email); + public UserWithPassword(Long id, String name, String displayName, String email, String password, Integer isAdmin) { + super(id, name, displayName, email, isAdmin); this.password = password; } - public static UserWithPassword get(Long id, String name, String displayName, String email, String password) { - return new UserWithPassword(id, name, displayName, email, password); + public static UserWithPassword get(Long id, String name, String displayName, + String email, String password, Integer isAdmin) { + return new UserWithPassword(id, name, displayName, email, password, isAdmin); } } diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/adaptor/DefaultUserAdaptor.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/adaptor/DefaultUserAdaptor.java index e762ca9a3..9d5893343 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/adaptor/DefaultUserAdaptor.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/adaptor/DefaultUserAdaptor.java @@ -71,7 +71,7 @@ public class DefaultUserAdaptor implements UserAdaptor { } if (userDO.getPassword().equals(userReq.getPassword())) { UserWithPassword user = UserWithPassword.get(userDO.getId(), userDO.getName(), userDO.getDisplayName(), - userDO.getEmail(), userDO.getPassword()); + userDO.getEmail(), userDO.getPassword(), userDO.getIsAdmin()); return userTokenUtils.generateToken(user); } throw new RuntimeException("password not correct, please try again"); diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/persistence/dataobject/UserDO.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/persistence/dataobject/UserDO.java index 77b4ae9e7..af32a9aff 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/persistence/dataobject/UserDO.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/persistence/dataobject/UserDO.java @@ -1,99 +1,129 @@ package com.tencent.supersonic.auth.authentication.persistence.dataobject; public class UserDO { - /** - * + * */ private Long id; /** - * + * */ private String name; /** - * + * */ private String password; /** - * + * */ private String displayName; /** - * + * */ private String email; /** - * @return id + * + */ + private Integer isAdmin; + + /** + * + * @return id */ public Long getId() { return id; } /** - * @param id + * + * @param id */ public void setId(Long id) { this.id = id; } /** - * @return name + * + * @return name */ public String getName() { return name; } /** - * @param name + * + * @param name */ public void setName(String name) { this.name = name == null ? null : name.trim(); } /** - * @return password + * + * @return password */ public String getPassword() { return password; } /** - * @param password + * + * @param password */ public void setPassword(String password) { this.password = password == null ? null : password.trim(); } /** - * @return display_name + * + * @return display_name */ public String getDisplayName() { return displayName; } /** - * @param displayName + * + * @param displayName */ public void setDisplayName(String displayName) { this.displayName = displayName == null ? null : displayName.trim(); } /** - * @return email + * + * @return email */ public String getEmail() { return email; } /** - * @param email + * + * @param email */ public void setEmail(String email) { this.email = email == null ? null : email.trim(); } + + /** + * + * @return is_admin + */ + public Integer getIsAdmin() { + return isAdmin; + } + + /** + * + * @param isAdmin + */ + public void setIsAdmin(Integer isAdmin) { + this.isAdmin = isAdmin; + } } \ No newline at end of file diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/persistence/dataobject/UserDOExample.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/persistence/dataobject/UserDOExample.java index 21f01f4ca..96d8fafdd 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/persistence/dataobject/UserDOExample.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/persistence/dataobject/UserDOExample.java @@ -4,7 +4,6 @@ import java.util.ArrayList; import java.util.List; public class UserDOExample { - /** * s2_user */ @@ -31,6 +30,7 @@ public class UserDOExample { protected Integer limitEnd; /** + * * @mbg.generated */ public UserDOExample() { @@ -38,13 +38,7 @@ public class UserDOExample { } /** - * @mbg.generated - */ - public String getOrderByClause() { - return orderByClause; - } - - /** + * * @mbg.generated */ public void setOrderByClause(String orderByClause) { @@ -52,13 +46,15 @@ public class UserDOExample { } /** + * * @mbg.generated */ - public boolean isDistinct() { - return distinct; + public String getOrderByClause() { + return orderByClause; } /** + * * @mbg.generated */ public void setDistinct(boolean distinct) { @@ -66,6 +62,15 @@ public class UserDOExample { } /** + * + * @mbg.generated + */ + public boolean isDistinct() { + return distinct; + } + + /** + * * @mbg.generated */ public List getOredCriteria() { @@ -73,6 +78,7 @@ public class UserDOExample { } /** + * * @mbg.generated */ public void or(Criteria criteria) { @@ -80,6 +86,7 @@ public class UserDOExample { } /** + * * @mbg.generated */ public Criteria or() { @@ -89,6 +96,7 @@ public class UserDOExample { } /** + * * @mbg.generated */ public Criteria createCriteria() { @@ -100,6 +108,7 @@ public class UserDOExample { } /** + * * @mbg.generated */ protected Criteria createCriteriaInternal() { @@ -108,6 +117,7 @@ public class UserDOExample { } /** + * * @mbg.generated */ public void clear() { @@ -117,6 +127,15 @@ public class UserDOExample { } /** + * + * @mbg.generated + */ + public void setLimitStart(Integer limitStart) { + this.limitStart=limitStart; + } + + /** + * * @mbg.generated */ public Integer getLimitStart() { @@ -124,31 +143,25 @@ public class UserDOExample { } /** + * * @mbg.generated */ - public void setLimitStart(Integer limitStart) { - this.limitStart = limitStart; + public void setLimitEnd(Integer limitEnd) { + this.limitEnd=limitEnd; } /** + * * @mbg.generated */ public Integer getLimitEnd() { return limitEnd; } - /** - * @mbg.generated - */ - public void setLimitEnd(Integer limitEnd) { - this.limitEnd = limitEnd; - } - /** * s2_user null */ protected abstract static class GeneratedCriteria { - protected List criteria; protected GeneratedCriteria() { @@ -528,6 +541,66 @@ public class UserDOExample { addCriterion("email not between", value1, value2, "email"); return (Criteria) this; } + + public Criteria andIsAdminIsNull() { + addCriterion("is_admin is null"); + return (Criteria) this; + } + + public Criteria andIsAdminIsNotNull() { + addCriterion("is_admin is not null"); + return (Criteria) this; + } + + public Criteria andIsAdminEqualTo(Integer value) { + addCriterion("is_admin =", value, "isAdmin"); + return (Criteria) this; + } + + public Criteria andIsAdminNotEqualTo(Integer value) { + addCriterion("is_admin <>", value, "isAdmin"); + return (Criteria) this; + } + + public Criteria andIsAdminGreaterThan(Integer value) { + addCriterion("is_admin >", value, "isAdmin"); + return (Criteria) this; + } + + public Criteria andIsAdminGreaterThanOrEqualTo(Integer value) { + addCriterion("is_admin >=", value, "isAdmin"); + return (Criteria) this; + } + + public Criteria andIsAdminLessThan(Integer value) { + addCriterion("is_admin <", value, "isAdmin"); + return (Criteria) this; + } + + public Criteria andIsAdminLessThanOrEqualTo(Integer value) { + addCriterion("is_admin <=", value, "isAdmin"); + return (Criteria) this; + } + + public Criteria andIsAdminIn(List values) { + addCriterion("is_admin in", values, "isAdmin"); + return (Criteria) this; + } + + public Criteria andIsAdminNotIn(List values) { + addCriterion("is_admin not in", values, "isAdmin"); + return (Criteria) this; + } + + public Criteria andIsAdminBetween(Integer value1, Integer value2) { + addCriterion("is_admin between", value1, value2, "isAdmin"); + return (Criteria) this; + } + + public Criteria andIsAdminNotBetween(Integer value1, Integer value2) { + addCriterion("is_admin not between", value1, value2, "isAdmin"); + return (Criteria) this; + } } /** @@ -544,7 +617,6 @@ public class UserDOExample { * s2_user null */ public static class Criterion { - private String condition; private Object value; @@ -561,6 +633,38 @@ public class UserDOExample { private String typeHandler; + public String getCondition() { + return condition; + } + + public Object getValue() { + return value; + } + + public Object getSecondValue() { + return secondValue; + } + + public boolean isNoValue() { + return noValue; + } + + public boolean isSingleValue() { + return singleValue; + } + + public boolean isBetweenValue() { + return betweenValue; + } + + public boolean isListValue() { + return listValue; + } + + public String getTypeHandler() { + return typeHandler; + } + protected Criterion(String condition) { super(); this.condition = condition; @@ -596,37 +700,5 @@ public class UserDOExample { protected Criterion(String condition, Object value, Object secondValue) { this(condition, value, secondValue, null); } - - public String getCondition() { - return condition; - } - - public Object getValue() { - return value; - } - - public Object getSecondValue() { - return secondValue; - } - - public boolean isNoValue() { - return noValue; - } - - public boolean isSingleValue() { - return singleValue; - } - - public boolean isBetweenValue() { - return betweenValue; - } - - public boolean isListValue() { - return listValue; - } - - public String getTypeHandler() { - return typeHandler; - } } } \ No newline at end of file diff --git a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/UserTokenUtils.java b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/UserTokenUtils.java index c8749ad43..82e93bcf3 100644 --- a/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/UserTokenUtils.java +++ b/auth/authentication/src/main/java/com/tencent/supersonic/auth/authentication/utils/UserTokenUtils.java @@ -2,6 +2,7 @@ package com.tencent.supersonic.auth.authentication.utils; import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_ALGORITHM; import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_CREATE_TIME; +import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_IS_ADMIN; import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_PREFIX; import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_TIME_OUT; import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_USER_DISPLAY_NAME; @@ -42,6 +43,7 @@ public class UserTokenUtils { claims.put(TOKEN_USER_PASSWORD, StringUtils.isEmpty(user.getPassword()) ? "" : user.getPassword()); claims.put(TOKEN_USER_DISPLAY_NAME, user.getDisplayName()); claims.put(TOKEN_CREATE_TIME, System.currentTimeMillis()); + claims.put(TOKEN_IS_ADMIN, user.getIsAdmin()); return generate(claims); } @@ -52,6 +54,7 @@ public class UserTokenUtils { claims.put(TOKEN_USER_PASSWORD, "admin"); claims.put(TOKEN_USER_DISPLAY_NAME, "admin"); claims.put(TOKEN_CREATE_TIME, System.currentTimeMillis()); + claims.put(TOKEN_IS_ADMIN, 1); return generate(claims); } @@ -63,7 +66,9 @@ public class UserTokenUtils { String userName = String.valueOf(claims.get(TOKEN_USER_NAME)); String email = String.valueOf(claims.get(TOKEN_USER_EMAIL)); String displayName = String.valueOf(claims.get(TOKEN_USER_DISPLAY_NAME)); - return User.get(userId, userName, displayName, email); + Integer isAdmin = claims.get(TOKEN_IS_ADMIN) == null + ? 0 : Integer.parseInt(claims.get(TOKEN_IS_ADMIN).toString()); + return User.get(userId, userName, displayName, email, isAdmin); } public UserWithPassword getUserWithPassword(HttpServletRequest request) { @@ -79,7 +84,9 @@ public class UserTokenUtils { String email = String.valueOf(claims.get(TOKEN_USER_EMAIL)); String displayName = String.valueOf(claims.get(TOKEN_USER_DISPLAY_NAME)); String password = String.valueOf(claims.get(TOKEN_USER_PASSWORD)); - return UserWithPassword.get(userId, userName, displayName, email, password); + Integer isAdmin = claims.get(TOKEN_IS_ADMIN) == null + ? 0 : Integer.parseInt(claims.get(TOKEN_IS_ADMIN).toString()); + return UserWithPassword.get(userId, userName, displayName, email, password, isAdmin); } private Claims getClaims(String token) { diff --git a/auth/authentication/src/main/resources/mapper/UserDOMapper.xml b/auth/authentication/src/main/resources/mapper/UserDOMapper.xml index 15eb2b49c..c1933db89 100644 --- a/auth/authentication/src/main/resources/mapper/UserDOMapper.xml +++ b/auth/authentication/src/main/resources/mapper/UserDOMapper.xml @@ -2,11 +2,12 @@ - + + @@ -38,7 +39,7 @@ - id, name, password, display_name, email + id, name, password, display_name, email, is_admin - - - delete from s2_user - where id = #{id,jdbcType=BIGINT} - insert into s2_user (id, name, password, - display_name, email) + display_name, email, is_admin + ) values (#{id,jdbcType=BIGINT}, #{name,jdbcType=VARCHAR}, #{password,jdbcType=VARCHAR}, - #{displayName,jdbcType=VARCHAR}, #{email,jdbcType=VARCHAR}) + #{displayName,jdbcType=VARCHAR}, #{email,jdbcType=VARCHAR}, #{isAdmin,jdbcType=INTEGER} + ) insert into s2_user @@ -91,6 +84,9 @@ email, + + is_admin, + @@ -108,6 +104,9 @@ #{email,jdbcType=VARCHAR}, + + #{isAdmin,jdbcType=INTEGER}, + - - update s2_user - - - name = #{name,jdbcType=VARCHAR}, - - - password = #{password,jdbcType=VARCHAR}, - - - display_name = #{displayName,jdbcType=VARCHAR}, - - - email = #{email,jdbcType=VARCHAR}, - - - where id = #{id,jdbcType=BIGINT} - - - update s2_user - set name = #{name,jdbcType=VARCHAR}, - password = #{password,jdbcType=VARCHAR}, - display_name = #{displayName,jdbcType=VARCHAR}, - email = #{email,jdbcType=VARCHAR} - where id = #{id,jdbcType=BIGINT} - \ No newline at end of file diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticLayer.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticLayer.java index dbd0226ee..a6be0a0ab 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticLayer.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticLayer.java @@ -8,9 +8,11 @@ import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq; import com.tencent.supersonic.semantic.api.model.request.PageMetricReq; import com.tencent.supersonic.semantic.api.model.response.DomainResp; import com.tencent.supersonic.semantic.api.model.response.DimensionResp; +import com.tencent.supersonic.semantic.api.model.response.ExplainResp; import com.tencent.supersonic.semantic.api.model.response.ModelResp; import com.tencent.supersonic.semantic.api.model.response.MetricResp; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; +import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq; import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq; import com.tencent.supersonic.semantic.api.query.request.QueryDslReq; import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq; @@ -32,14 +34,27 @@ import java.util.List; public interface SemanticLayer { QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user); + QueryResultWithSchemaResp queryByMultiStruct(QueryMultiStructReq queryMultiStructReq, User user); + QueryResultWithSchemaResp queryByDsl(QueryDslReq queryDslReq, User user); + QueryResultWithSchemaResp queryDimValue(QueryDimValueReq queryDimValueReq, User user); + List getModelSchema(); + List getModelSchema(List ids); + ModelSchema getModelSchema(Long model, Boolean cacheEnable); + PageInfo getDimensionPage(PageDimensionReq pageDimensionCmd); - PageInfo getMetricPage(PageMetricReq pageMetricCmd); + + PageInfo getMetricPage(PageMetricReq pageMetricCmd, User user); + List getDomainList(User user); + List getModelList(AuthType authType, Long domainId, User user); + + ExplainResp explain(ExplainSqlReq explainSqlReq, User user) throws Exception; + } diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticQuery.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticQuery.java index 91acf9016..70d263153 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticQuery.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/component/SemanticQuery.java @@ -3,6 +3,7 @@ package com.tencent.supersonic.chat.api.component; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; +import com.tencent.supersonic.semantic.api.model.response.ExplainResp; import org.apache.calcite.sql.parser.SqlParseException; /** @@ -14,6 +15,8 @@ public interface SemanticQuery { QueryResult execute(User user) throws SqlParseException; + ExplainResp explain(User user); + SemanticParseInfo getParseInfo(); void setParseInfo(SemanticParseInfo parseInfo); diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticParseInfo.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticParseInfo.java index 671aa91a4..e8edc777b 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticParseInfo.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticParseInfo.java @@ -37,6 +37,8 @@ public class SemanticParseInfo { private List elementMatches = new ArrayList<>(); private Map properties = new HashMap<>(); private EntityInfo entityInfo; + private String sql; + public Long getModelId() { return model != null ? model.getId() : 0L; } @@ -46,6 +48,7 @@ public class SemanticParseInfo { } private static class SchemaNameLengthComparator implements Comparator { + @Override public int compare(SchemaElement o1, SchemaElement o2) { int len1 = o1.getName().length(); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/config/LLMParserConfig.java b/chat/core/src/main/java/com/tencent/supersonic/chat/config/LLMParserConfig.java index e44029538..6032b798c 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/config/LLMParserConfig.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/config/LLMParserConfig.java @@ -16,4 +16,10 @@ public class LLMParserConfig { @Value("${query2sql.path:/query2sql}") private String queryToSqlPath; + @Value("${dimension.topn:5}") + private Integer dimensionTopN; + + @Value("${metric.topn:5}") + private Integer metricTopN; + } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java index 9d486d6e9..f286c1177 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java @@ -39,6 +39,7 @@ import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum; import java.util.ArrayList; import java.util.Arrays; +import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -87,7 +88,7 @@ public class LLMDslParser implements SemanticParser { return; } - LLMReq llmReq = getLlmReq(queryCtx, modelId); + LLMReq llmReq = getLlmReq(queryCtx, modelId, llmParserConfig); LLMResp llmResp = requestLLM(llmReq, modelId, llmParserConfig); if (Objects.isNull(llmResp)) { @@ -340,22 +341,28 @@ public class LLMDslParser implements SemanticParser { return null; } - private LLMReq getLlmReq(QueryContext queryCtx, Long modelId) { + private LLMReq getLlmReq(QueryContext queryCtx, Long modelId, LLMParserConfig llmParserConfig) { SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); Map modelIdToName = semanticSchema.getModelIdToName(); String queryText = queryCtx.getRequest().getQueryText(); + LLMReq llmReq = new LLMReq(); llmReq.setQueryText(queryText); + LLMReq.LLMSchema llmSchema = new LLMReq.LLMSchema(); llmSchema.setModelName(modelIdToName.get(modelId)); llmSchema.setDomainName(modelIdToName.get(modelId)); - List fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema); + + List fieldNameList = getFieldNameList(queryCtx, modelId, semanticSchema, llmParserConfig); + fieldNameList.add(BaseSemanticCorrector.DATE_FIELD); llmSchema.setFieldNameList(fieldNameList); llmReq.setSchema(llmSchema); + List linking = new ArrayList<>(); linking.addAll(getValueList(queryCtx, modelId, semanticSchema)); llmReq.setLinking(linking); + String currentDate = DSLDateHelper.getReferenceDate(modelId); llmReq.setCurrentDate(currentDate); return llmReq; @@ -399,12 +406,29 @@ public class LLMDslParser implements SemanticParser { } - protected List getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) { + protected List getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema, + LLMParserConfig llmParserConfig) { Map itemIdToName = getItemIdToName(modelId, semanticSchema); + Set results = semanticSchema.getDimensions().stream() + .filter(schemaElement -> modelId.equals(schemaElement.getModel())) + .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) + .limit(llmParserConfig.getDimensionTopN()) + .map(entry -> entry.getName()) + .collect(Collectors.toSet()); + + Set metrics = semanticSchema.getMetrics().stream() + .filter(schemaElement -> modelId.equals(schemaElement.getModel())) + .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) + .limit(llmParserConfig.getMetricTopN()) + .map(entry -> entry.getName()) + .collect(Collectors.toSet()); + + results.addAll(metrics); + List matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId); if (CollectionUtils.isEmpty(matchedElements)) { - return new ArrayList<>(); + return new ArrayList<>(results); } Set fieldNameList = matchedElements.stream() .filter(schemaElementMatch -> { @@ -423,7 +447,8 @@ public class LLMDslParser implements SemanticParser { }) .filter(name -> StringUtils.isNotEmpty(name) && !name.contains("%")) .collect(Collectors.toSet()); - return new ArrayList<>(fieldNameList); + results.addAll(fieldNameList); + return new ArrayList<>(results); } protected Map getItemIdToName(Long modelId, SemanticSchema semanticSchema) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/dsl/DslQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/dsl/DslQuery.java index 64b97bbf0..58764882a 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/dsl/DslQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/llm/dsl/DslQuery.java @@ -15,7 +15,10 @@ import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.QueryColumn; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.JsonUtil; +import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum; +import com.tencent.supersonic.semantic.api.model.response.ExplainResp; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; +import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq; import com.tencent.supersonic.semantic.api.query.request.QueryDslReq; import java.util.ArrayList; import java.util.List; @@ -42,12 +45,10 @@ public class DslQuery extends PluginSemanticQuery { @Override public QueryResult execute(User user) { - String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT)); - DSLParseResult dslParseResult = JsonUtil.toObject(json, DSLParseResult.class); - LLMResp llmResp = dslParseResult.getLlmResp(); + LLMResp llmResp = getLlmResp(); long startTime = System.currentTimeMillis(); - QueryDslReq queryDslReq = QueryReqBuilder.buildDslReq(llmResp.getCorrectorSql(), parseInfo.getModelId()); + QueryDslReq queryDslReq = getQueryDslReq(llmResp); QueryResultWithSchemaResp queryResp = semanticLayer.queryByDsl(queryDslReq, user); log.info("queryByDsl cost:{},querySql:{}", System.currentTimeMillis() - startTime, llmResp.getSqlOutput()); @@ -71,4 +72,30 @@ public class DslQuery extends PluginSemanticQuery { parseInfo.setProperties(null); return queryResult; } + + private LLMResp getLlmResp() { + String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT)); + DSLParseResult dslParseResult = JsonUtil.toObject(json, DSLParseResult.class); + return dslParseResult.getLlmResp(); + } + + private QueryDslReq getQueryDslReq(LLMResp llmResp) { + QueryDslReq queryDslReq = QueryReqBuilder.buildDslReq(llmResp.getCorrectorSql(), parseInfo.getModelId()); + return queryDslReq; + } + + @Override + public ExplainResp explain(User user) { + ExplainSqlReq explainSqlReq = null; + try { + explainSqlReq = ExplainSqlReq.builder() + .queryTypeEnum(QueryTypeEnum.SQL) + .queryReq(getQueryDslReq(getLlmResp())) + .build(); + return semanticLayer.explain(explainSqlReq, user); + } catch (Exception e) { + log.error("explain error explainSqlReq:{}", explainSqlReq, e); + } + return null; + } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/PluginSemanticQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/PluginSemanticQuery.java index da2b48072..6e0341af0 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/PluginSemanticQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/plugin/PluginSemanticQuery.java @@ -1,7 +1,9 @@ package com.tencent.supersonic.chat.query.plugin; +import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.component.SemanticQuery; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.semantic.api.model.response.ExplainResp; import lombok.extern.slf4j.Slf4j; @Slf4j @@ -17,5 +19,8 @@ public abstract class PluginSemanticQuery implements SemanticQuery { return parseInfo; } - + @Override + public ExplainResp explain(User user) { + return null; + } } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java b/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java index d8348c077..4b5a8ac74 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/query/rule/RuleSemanticQuery.java @@ -21,8 +21,11 @@ import com.tencent.supersonic.chat.utils.ComponentFactory; import com.tencent.supersonic.chat.utils.QueryReqBuilder; import com.tencent.supersonic.common.pojo.QueryColumn; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum; +import com.tencent.supersonic.semantic.api.model.response.ExplainResp; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum; +import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq; import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import java.io.Serializable; @@ -215,6 +218,22 @@ public abstract class RuleSemanticQuery implements SemanticQuery, Serializable { return queryResult; } + + @Override + public ExplainResp explain(User user) { + ExplainSqlReq explainSqlReq = null; + try { + explainSqlReq = ExplainSqlReq.builder() + .queryTypeEnum(QueryTypeEnum.STRUCT) + .queryReq(convertQueryStruct()) + .build(); + return semanticLayer.explain(explainSqlReq, user); + } catch (Exception e) { + log.error("explain error explainSqlReq:{}", explainSqlReq, e); + } + return null; + } + public QueryResult multiStructExecute(User user) { String queryMode = parseInfo.getQueryMode(); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatConfigController.java b/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatConfigController.java index 88612800a..d8521e73a 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatConfigController.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatConfigController.java @@ -110,17 +110,18 @@ public class ChatConfigController { } @PostMapping("/dimension/page") - public PageInfo getDimension(@RequestBody PageDimensionReq pageDimensionCmd, + public PageInfo getDimension(@RequestBody PageDimensionReq pageDimensionReq, HttpServletRequest request, HttpServletResponse response) { - return semanticLayer.getDimensionPage(pageDimensionCmd); + return semanticLayer.getDimensionPage(pageDimensionReq); } @PostMapping("/metric/page") - public PageInfo getMetric(@RequestBody PageMetricReq pageMetrricCmd, + public PageInfo getMetric(@RequestBody PageMetricReq pageMetricReq, HttpServletRequest request, HttpServletResponse response) { - return semanticLayer.getMetricPage(pageMetrricCmd); + User user = UserHolder.findUser(request, response); + return semanticLayer.getMetricPage(pageMetricReq, user); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index a3b1050c9..817578804 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -29,6 +29,7 @@ import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.chat.service.StatisticsService; import com.tencent.supersonic.chat.utils.ComponentFactory; +import com.tencent.supersonic.semantic.api.model.response.ExplainResp; import java.util.List; import java.util.ArrayList; import java.util.Set; @@ -37,9 +38,7 @@ import java.util.Comparator; import java.util.Objects; import java.util.stream.Collectors; -//import com.tencent.supersonic.common.pojo.Aggregator; import com.tencent.supersonic.common.pojo.DateConf; -//import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; @@ -68,8 +67,6 @@ public class QueryServiceImpl implements QueryService { @Autowired private QueryResponder queryResponder; - private final String entity = "ENTITY"; - @Value("${time.threshold: 100}") private Integer timeThreshold; @@ -113,12 +110,16 @@ public class QueryServiceImpl implements QueryService { .map(SemanticQuery::getParseInfo) .sorted(Comparator.comparingDouble(SemanticParseInfo::getScore).reversed()) .collect(Collectors.toList()); + selectedParses.forEach(parseInfo -> { - if (parseInfo.getQueryMode().contains(entity)) { + String queryMode = parseInfo.getQueryMode(); + if (QueryManager.isEntityQuery(queryMode)) { EntityInfo entityInfo = ContextUtils.getBean(SemanticService.class) .getEntityInfo(parseInfo, queryReq.getUser()); parseInfo.setEntityInfo(entityInfo); } + addExplainSql(queryReq, parseInfo); + }); List candidateParses = queryCtx.getCandidateQueries().stream() .map(SemanticQuery::getParseInfo).collect(Collectors.toList()); @@ -145,6 +146,19 @@ public class QueryServiceImpl implements QueryService { return parseResult; } + private void addExplainSql(QueryReq queryReq, SemanticParseInfo parseInfo) { + SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode()); + if (Objects.isNull(semanticQuery)) { + return; + } + semanticQuery.setParseInfo(parseInfo); + ExplainResp explain = semanticQuery.explain(queryReq.getUser()); + if (Objects.isNull(explain)) { + return; + } + parseInfo.setSql(explain.getSql()); + } + @Override public QueryResult performExecution(ExecuteQueryReq queryReq) throws Exception { ChatParseDO chatParseDO = chatService.getParseInfo(queryReq.getQueryId(), @@ -162,9 +176,9 @@ public class QueryServiceImpl implements QueryService { chatCtx.setAgentId(queryReq.getAgentId()); Long startTime = System.currentTimeMillis(); QueryResult queryResult = semanticQuery.execute(queryReq.getUser()); - Long endTime = System.currentTimeMillis(); + if (queryResult != null) { - timeCostDOList.add(StatisticsDO.builder().cost((int) (endTime - startTime)) + timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime)) .interfaceName(semanticQuery.getClass().getSimpleName()).type(CostType.QUERY.getType()).build()); saveInfo(timeCostDOList, queryReq.getQueryText(), queryReq.getQueryId(), queryReq.getUser().getName(), queryReq.getChatId().longValue()); @@ -176,7 +190,6 @@ public class QueryServiceImpl implements QueryService { } chatCtx.setQueryText(queryReq.getQueryText()); chatCtx.setUser(queryReq.getUser().getName()); - //chatService.addQuery(queryResult, chatCtx); chatService.updateQuery(queryReq.getQueryId(), queryResult, chatCtx); queryResponder.saveSolvedQuery(queryReq.getQueryText(), queryReq.getQueryId(), queryReq.getParseId()); } else { @@ -187,8 +200,8 @@ public class QueryServiceImpl implements QueryService { } public void saveInfo(List timeCostDOList, - String queryText, Long queryId, - String userName, Long chatId) { + String queryText, Long queryId, + String userName, Long chatId) { List list = timeCostDOList.stream() .filter(o -> o.getCost() > timeThreshold).collect(Collectors.toList()); list.forEach(o -> { @@ -272,13 +285,6 @@ public class QueryServiceImpl implements QueryService { dateConf.setPeriod("DAY"); queryStructReq.setDateInfo(dateConf); queryStructReq.setLimit(20L); - - // List aggregators = new ArrayList<>(); - // Aggregator aggregator = new Aggregator(dimensionValueReq.getQueryFilter().getBizName(), - // AggOperatorEnum.DISTINCT); - // aggregators.add(aggregator); - // queryStructReq.setAggregators(aggregators); - queryStructReq.setModelId(dimensionValueReq.getModelId()); queryStructReq.setNativeQuery(true); List groups = new ArrayList<>(); diff --git a/chat/core/src/main/python/run_config.py b/chat/core/src/main/python/run_config.py index e0fe5f1aa..2d4cbaf53 100644 --- a/chat/core/src/main/python/run_config.py +++ b/chat/core/src/main/python/run_config.py @@ -15,6 +15,7 @@ CHROMA_DB_PERSIST_DIR = 'chm_db' PRESET_QUERY_COLLECTION_NAME = "preset_query_collection" TEXT2DSL_COLLECTION_NAME = "text2dsl_collection" TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15 +TEXT2DSL_IS_SHORTCUT = False CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR) diff --git a/chat/core/src/main/python/sql/constructor.py b/chat/core/src/main/python/sql/constructor.py index b844a84fe..2553e4eca 100644 --- a/chat/core/src/main/python/sql/constructor.py +++ b/chat/core/src/main/python/sql/constructor.py @@ -22,10 +22,8 @@ from util.text2vec import Text2VecEmbeddingFunction, hg_embedding from util.chromadb_instance import client as chromadb_client, empty_chroma_collection_2 from run_config import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM - def reload_sql_example_collection(vectorstore:Chroma, sql_examplars:List[Mapping[str, str]], - schema_linking_example_selector:SemanticSimilarityExampleSelector, sql_example_selector:SemanticSimilarityExampleSelector, example_nums:int ): @@ -35,20 +33,16 @@ def reload_sql_example_collection(vectorstore:Chroma, print("emptied sql_examples_collection size:", vectorstore._collection.count()) - schema_linking_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums, - input_keys=["question"], - example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"]) - sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums, - input_keys=["question"], - example_keys=["question", "current_date", "table_name", "schema_links", "sql"]) + input_keys=["question"], + example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links", "current_date", "sql"]) for example in sql_examplars: - schema_linking_example_selector.add_example(example) + sql_example_selector.add_example(example) print("reloaded sql_examples_collection size:", vectorstore._collection.count()) - return vectorstore, schema_linking_example_selector, sql_example_selector + return vectorstore, sql_example_selector sql_examples_vectorstore = Chroma(collection_name=TEXT2DSL_COLLECTION_NAME, @@ -57,22 +51,14 @@ sql_examples_vectorstore = Chroma(collection_name=TEXT2DSL_COLLECTION_NAME, example_nums = TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM -schema_linking_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums, - input_keys=["question"], - example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"]) - sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums, - input_keys=["question"], - example_keys=["question", "current_date", "table_name", "schema_links", "sql"]) + input_keys=["question"], + example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links", "current_date", "sql"]) if sql_examples_vectorstore._collection.count() > 0: print("examples already in sql_vectorstore") print("init sql_vectorstore size:", sql_examples_vectorstore._collection.count()) - if sql_examples_vectorstore._collection.count() < len(sql_examplars): - print("sql_examplars size:", len(sql_examplars)) - sql_examples_vectorstore, schema_linking_example_selector, sql_example_selector = reload_sql_example_collection(sql_examples_vectorstore, sql_examplars, schema_linking_example_selector, sql_example_selector, example_nums) - print("added sql_vectorstore size:", sql_examples_vectorstore._collection.count()) -else: - sql_examples_vectorstore, schema_linking_example_selector, sql_example_selector = reload_sql_example_collection(sql_examples_vectorstore, sql_examplars, schema_linking_example_selector, sql_example_selector, example_nums) - print("added sql_vectorstore size:", sql_examples_vectorstore._collection.count()) +print("sql_examplars size:", len(sql_examplars)) +sql_examples_vectorstore, sql_example_selector = reload_sql_example_collection(sql_examples_vectorstore, sql_examplars, sql_example_selector, example_nums) +print("added sql_vectorstore size:", sql_examples_vectorstore._collection.count()) diff --git a/chat/core/src/main/python/sql/examples_reload_run.py b/chat/core/src/main/python/sql/examples_reload_run.py index 65f1e3bed..65df9087d 100644 --- a/chat/core/src/main/python/sql/examples_reload_run.py +++ b/chat/core/src/main/python/sql/examples_reload_run.py @@ -8,24 +8,22 @@ import json sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.abspath(__file__))) -from run_config import TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM +from run_config import TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT from few_shot_example.sql_exampler import examplars as sql_examplars -from run_config import LLMPARSER_HOST -from run_config import LLMPARSER_PORT +from run_config import LLMPARSER_HOST, LLMPARSER_PORT def text2dsl_setting_update(llm_parser_host:str, llm_parser_port:str, - sql_examplars:List[Mapping[str, str]], example_nums:int): + sql_examplars:List[Mapping[str, str]], example_nums:int, is_shortcut:bool): url = f"http://{llm_parser_host}:{llm_parser_port}/query2sql_setting_update/" print("url: ", url) - payload = {"sqlExamplars":sql_examplars, "exampleNums":example_nums} + payload = {"sqlExamplars":sql_examplars, "exampleNums":example_nums, "isShortcut":is_shortcut} headers = {'content-type': 'application/json'} response = requests.post(url, data=json.dumps(payload), headers=headers) print(response.text) if __name__ == "__main__": - arguments = sys.argv text2dsl_setting_update(LLMPARSER_HOST, LLMPARSER_PORT, - sql_examplars, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM) + sql_examplars, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT) diff --git a/chat/core/src/main/python/sql/output_parser.py b/chat/core/src/main/python/sql/output_parser.py index c90388850..aa0ff317f 100644 --- a/chat/core/src/main/python/sql/output_parser.py +++ b/chat/core/src/main/python/sql/output_parser.py @@ -10,4 +10,36 @@ def schema_link_parse(schema_link_output): print(e) schema_link_output = None - return schema_link_output \ No newline at end of file + return schema_link_output + +def combo_schema_link_parse(schema_linking_sql_combo_output: str): + try: + schema_linking_sql_combo_output = schema_linking_sql_combo_output.strip() + pattern = r'Schema_links:(\[.*?\])' + schema_links_match = re.search(pattern, schema_linking_sql_combo_output) + + if schema_links_match: + schema_links = schema_links_match.group(1) + else: + schema_links = None + except Exception as e: + print(e) + schema_links = None + + return schema_links + +def combo_sql_parse(schema_linking_sql_combo_output: str): + try: + schema_linking_sql_combo_output = schema_linking_sql_combo_output.strip() + pattern = r'SQL:(.*)' + sql_match = re.search(pattern, schema_linking_sql_combo_output) + + if sql_match: + sql = sql_match.group(1) + else: + sql = None + except Exception as e: + print(e) + sql = None + + return sql diff --git a/chat/core/src/main/python/sql/prompt_maker.py b/chat/core/src/main/python/sql/prompt_maker.py index 0cfed83b1..7c4f5fccc 100644 --- a/chat/core/src/main/python/sql/prompt_maker.py +++ b/chat/core/src/main/python/sql/prompt_maker.py @@ -73,3 +73,38 @@ def sql_exampler(user_query: str, schema_links=schema_link_str) return sql_example_prompt + + +def schema_linking_sql_combo_examplar(user_query: str, + domain_name: str, + data_date : str, + fields_list: List[str], + prior_schema_links: Mapping[str,str], + example_selector: SemanticSimilarityExampleSelector) -> str: + + prior_schema_links_str = '['+ ','.join(["""'{}'->{}""".format(k,v) for k,v in prior_schema_links.items()]) + ']' + + example_prompt_template = PromptTemplate(input_variables=["table_name", "fields_list", "prior_schema_links", "current_date", "question", "analysis", "schema_links", "sql"], + template="Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}\nSQL:{sql}") + + instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links,再根据schema_links为每个问题生成SQL查询语句" + + schema_linking_sql_combo_prompt = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析: 让我们一步一步地思考。" + + schema_linking_sql_combo_example_prompt_template = FewShotPromptTemplate( + example_selector=example_selector, + example_prompt=example_prompt_template, + example_separator="\n\n", + prefix=instruction, + input_variables=["table_name", "fields_list", "prior_schema_links", "current_date", "question"], + suffix=schema_linking_sql_combo_prompt + ) + + schema_linking_sql_combo_example_prompt = schema_linking_sql_combo_example_prompt_template.format(table_name=domain_name, + fields_list=fields_list, + prior_schema_links=prior_schema_links_str, + current_date=data_date, + question=user_query) + return schema_linking_sql_combo_example_prompt + + diff --git a/chat/core/src/main/python/sql/run.py b/chat/core/src/main/python/sql/run.py index a7ece82d8..02931b5c8 100644 --- a/chat/core/src/main/python/sql/run.py +++ b/chat/core/src/main/python/sql/run.py @@ -7,32 +7,37 @@ import sys sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.abspath(__file__))) -from sql.prompt_maker import schema_linking_exampler, sql_exampler -from sql.constructor import schema_linking_example_selector, sql_example_selector,sql_examples_vectorstore, reload_sql_example_collection -from sql.output_parser import schema_link_parse +from sql.prompt_maker import schema_linking_exampler, sql_exampler, schema_linking_sql_combo_examplar +from sql.constructor import sql_examples_vectorstore, sql_example_selector, reload_sql_example_collection +from sql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse from util.llm_instance import llm - +from run_config import TEXT2DSL_IS_SHORTCUT class Text2DSLAgent(object): def __init__(self): self.schema_linking_exampler = schema_linking_exampler self.sql_exampler = sql_exampler + self.schema_linking_sql_combo_exampler = schema_linking_sql_combo_examplar + self.sql_examples_vectorstore = sql_examples_vectorstore - self.schema_linking_example_selector = schema_linking_example_selector self.sql_example_selector = sql_example_selector self.schema_link_parse = schema_link_parse + self.combo_schema_link_parse = combo_schema_link_parse + self.combo_sql_parse = combo_sql_parse self.llm = llm - def update_examples(self, sql_examplars, example_nums): - self.sql_examples_vectorstore, self.schema_linking_example_selector, self.sql_example_selector = reload_sql_example_collection(self.sql_examples_vectorstore, - sql_examplars, - self.schema_linking_example_selector, - self.sql_example_selector, - example_nums) + self.is_shortcut = TEXT2DSL_IS_SHORTCUT + + def update_examples(self, sql_examples, example_nums, is_shortcut): + self.sql_examples_vectorstore, self.sql_example_selector = reload_sql_example_collection(self.sql_examples_vectorstore, + sql_examples, + self.sql_example_selector, + example_nums) + self.is_shortcut = is_shortcut def query2sql(self, query_text: str, schema : Union[dict, None] = None, @@ -53,14 +58,14 @@ class Text2DSLAgent(object): model_name = schema['modelName'] fields_list = schema['fieldNameList'] - schema_linking_prompt = self.schema_linking_exampler(query_text, model_name, fields_list, prior_schema_links, self.schema_linking_example_selector) + schema_linking_prompt = self.schema_linking_exampler(query_text, model_name, fields_list, prior_schema_links, self.sql_example_selector) print("schema_linking_prompt->", schema_linking_prompt) schema_link_output = self.llm(schema_linking_prompt) schema_link_str = self.schema_link_parse(schema_link_output) sql_prompt = self.sql_exampler(query_text, model_name, schema_link_str, current_date, self.sql_example_selector) print("sql_prompt->", sql_prompt) - sql_output = llm(sql_prompt) + sql_output = self.llm(sql_prompt) resp = dict() resp['query'] = query_text @@ -69,7 +74,7 @@ class Text2DSLAgent(object): resp['priorSchemaLinking'] = linking resp['dataDate'] = current_date - resp['schemaLinkingOutput'] = schema_link_output + resp['analysisOutput'] = schema_link_output resp['schemaLinkStr'] = schema_link_str resp['sqlOutput'] = sql_output @@ -78,5 +83,57 @@ class Text2DSLAgent(object): return resp + def query2sqlcombo(self, query_text: str, + schema : Union[dict, None] = None, + current_date: str = None, + linking: Union[List[Mapping[str, str]], None] = None + ): + + print("query_text: ", query_text) + print("schema: ", schema) + print("current_date: ", current_date) + print("prior_schema_links: ", linking) + + if linking is not None: + prior_schema_links = {item['fieldValue']:item['fieldName'] for item in linking} + else: + prior_schema_links = {} + + model_name = schema['modelName'] + fields_list = schema['fieldNameList'] + + schema_linking_sql_combo_prompt = self.schema_linking_sql_combo_exampler(query_text, model_name, current_date, fields_list, + prior_schema_links, self.sql_example_selector) + print("schema_linking_sql_combo_prompt->", schema_linking_sql_combo_prompt) + schema_linking_sql_combo_output = self.llm(schema_linking_sql_combo_prompt) + + schema_linking_str = self.combo_schema_link_parse(schema_linking_sql_combo_output) + sql_str = self.combo_sql_parse(schema_linking_sql_combo_output) + + resp = dict() + resp['query'] = query_text + resp['model'] = model_name + resp['fields'] = fields_list + resp['priorSchemaLinking'] = prior_schema_links + resp['dataDate'] = current_date + + resp['analysisOutput'] = schema_linking_sql_combo_output + resp['schemaLinkStr'] = schema_linking_str + resp['sqlOutput'] = sql_str + + print("resp: ", resp) + + return resp + + def query2sql_run(self, query_text: str, + schema : Union[dict, None] = None, + current_date: str = None, + linking: Union[List[Mapping[str, str]], None] = None): + + if self.is_shortcut: + return self.query2sqlcombo(query_text, schema, current_date, linking) + else: + return self.query2sql(query_text, schema, current_date, linking) + text2sql_agent = Text2DSLAgent() diff --git a/chat/core/src/main/python/supersonic_llmparser.py b/chat/core/src/main/python/supersonic_llmparser.py index 963328a27..40ebfe613 100644 --- a/chat/core/src/main/python/supersonic_llmparser.py +++ b/chat/core/src/main/python/supersonic_llmparser.py @@ -51,7 +51,7 @@ async def din_query2sql(query_body: Mapping[str, Any]): else: linking = query_body['linking'] - resp = text2sql_agent.query2sql(query_text=query_text, + resp = text2sql_agent.query2sql_run(query_text=query_text, schema=schema, current_date=current_date, linking=linking) return resp @@ -70,7 +70,12 @@ async def query2sql_setting_update(query_body: Mapping[str, Any]): else: example_nums = query_body['exampleNums'] - text2sql_agent.update_examples(sql_examplars=sql_examplars, example_nums=example_nums) + if 'isShortcut' not in query_body: + raise HTTPException(status_code=400, detail="isShortcut is not in query_body") + else: + is_shortcut = query_body['isShortcut'] + + text2sql_agent.update_examples(sql_examples=sql_examplars, example_nums=example_nums, is_shortcut=is_shortcut) return "success" diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/DefaultSemanticConfig.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/DefaultSemanticConfig.java index a28b71682..2152da0b7 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/DefaultSemanticConfig.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/DefaultSemanticConfig.java @@ -38,4 +38,7 @@ public class DefaultSemanticConfig { @Value("${fetchModelList.path:/api/semantic/schema/model/list}") private String fetchModelListPath; + @Value("${explain.path:/api/semantic/query/explain}") + private String explainPath; + } diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticLayer.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticLayer.java index d0dceaaa5..81a445afc 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticLayer.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticLayer.java @@ -8,12 +8,14 @@ import com.tencent.supersonic.common.pojo.enums.AuthType; import com.tencent.supersonic.semantic.api.model.request.ModelSchemaFilterReq; import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq; import com.tencent.supersonic.semantic.api.model.request.PageMetricReq; +import com.tencent.supersonic.semantic.api.model.response.ExplainResp; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; import com.tencent.supersonic.semantic.api.model.response.ModelResp; import com.tencent.supersonic.semantic.api.model.response.MetricResp; import com.tencent.supersonic.semantic.api.model.response.DomainResp; import com.tencent.supersonic.semantic.api.model.response.DimensionResp; import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp; +import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq; import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq; import com.tencent.supersonic.semantic.api.query.request.QueryDslReq; import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq; @@ -79,8 +81,9 @@ public class LocalSemanticLayer extends BaseSemanticLayer { public List doFetchModelSchema(List ids) { ModelSchemaFilterReq filter = new ModelSchemaFilterReq(); filter.setModelIds(ids); - modelService = ContextUtils.getBean(ModelService.class); - return modelService.fetchModelSchema(filter); + schemaService = ContextUtils.getBean(SchemaService.class); + User user = User.getFakeUser(); + return schemaService.fetchModelSchema(filter, user); } @Override @@ -95,6 +98,12 @@ public class LocalSemanticLayer extends BaseSemanticLayer { return schemaService.getModelList(user, authType, domainId); } + @Override + public ExplainResp explain(ExplainSqlReq explainSqlReq, User user) throws Exception { + queryService = ContextUtils.getBean(QueryService.class); + return queryService.explain(explainSqlReq, user); + } + @Override public PageInfo getDimensionPage(PageDimensionReq pageDimensionCmd) { dimensionService = ContextUtils.getBean(DimensionService.class); @@ -102,9 +111,9 @@ public class LocalSemanticLayer extends BaseSemanticLayer { } @Override - public PageInfo getMetricPage(PageMetricReq pageMetricReq) { + public PageInfo getMetricPage(PageMetricReq pageMetricReq, User user) { metricService = ContextUtils.getBean(MetricService.class); - return metricService.queryMetric(pageMetricReq); + return metricService.queryMetric(pageMetricReq, user); } } diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/RemoteSemanticLayer.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/RemoteSemanticLayer.java index 05630263d..c18479162 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/RemoteSemanticLayer.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/RemoteSemanticLayer.java @@ -1,38 +1,44 @@ package com.tencent.supersonic.knowledge.semantic; +import static com.tencent.supersonic.common.pojo.Constants.LIST_LOWER; +import static com.tencent.supersonic.common.pojo.Constants.PAGESIZE_LOWER; +import static com.tencent.supersonic.common.pojo.Constants.TOTAL_LOWER; +import static com.tencent.supersonic.common.pojo.Constants.TRUE_LOWER; + import com.alibaba.fastjson.JSON; import com.github.pagehelper.PageInfo; import com.google.gson.Gson; import com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig; import com.tencent.supersonic.auth.api.authentication.constant.UserConstants; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.common.pojo.ResultData; +import com.tencent.supersonic.common.pojo.ReturnCode; +import com.tencent.supersonic.common.pojo.enums.AuthType; +import com.tencent.supersonic.common.pojo.exception.CommonException; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.JsonUtil; import com.tencent.supersonic.common.util.S2ThreadContext; import com.tencent.supersonic.common.util.ThreadContext; -import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.common.pojo.enums.AuthType; import com.tencent.supersonic.semantic.api.model.request.ModelSchemaFilterReq; import com.tencent.supersonic.semantic.api.model.request.PageDimensionReq; import com.tencent.supersonic.semantic.api.model.request.PageMetricReq; -import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; -import com.tencent.supersonic.semantic.api.model.response.ModelResp; -import com.tencent.supersonic.semantic.api.model.response.MetricResp; -import com.tencent.supersonic.semantic.api.model.response.DomainResp; import com.tencent.supersonic.semantic.api.model.response.DimensionResp; +import com.tencent.supersonic.semantic.api.model.response.DomainResp; +import com.tencent.supersonic.semantic.api.model.response.ExplainResp; +import com.tencent.supersonic.semantic.api.model.response.MetricResp; +import com.tencent.supersonic.semantic.api.model.response.ModelResp; import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp; +import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; +import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq; import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq; import com.tencent.supersonic.semantic.api.query.request.QueryDslReq; import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; -import com.tencent.supersonic.common.pojo.exception.CommonException; -import com.tencent.supersonic.common.pojo.ResultData; -import com.tencent.supersonic.common.pojo.ReturnCode; - import java.net.URI; +import java.net.URL; +import java.util.LinkedHashMap; import java.util.List; import java.util.Objects; -import java.util.LinkedHashMap; - import lombok.extern.slf4j.Slf4j; import org.apache.logging.log4j.util.Strings; import org.springframework.beans.BeanUtils; @@ -45,11 +51,6 @@ import org.springframework.http.ResponseEntity; import org.springframework.web.client.RestTemplate; import org.springframework.web.util.UriComponentsBuilder; -import static com.tencent.supersonic.common.pojo.Constants.TRUE_LOWER; -import static com.tencent.supersonic.common.pojo.Constants.LIST_LOWER; -import static com.tencent.supersonic.common.pojo.Constants.TOTAL_LOWER; -import static com.tencent.supersonic.common.pojo.Constants.PAGESIZE_LOWER; - @Slf4j public class RemoteSemanticLayer extends BaseSemanticLayer { @@ -61,6 +62,10 @@ public class RemoteSemanticLayer extends BaseSemanticLayer { new ParameterizedTypeReference>() { }; + private ParameterizedTypeReference> explainTypeRef = + new ParameterizedTypeReference>() { + }; + @Override public QueryResultWithSchemaResp queryByStruct(QueryStructReq queryStructReq, User user) { DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class); @@ -130,9 +135,10 @@ public class RemoteSemanticLayer extends BaseSemanticLayer { fillToken(headers); DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class); - URI requestUrl = UriComponentsBuilder.fromHttpUrl( - defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getFetchModelSchemaPath()).build() - .encode().toUri(); + String semanticUrl = defaultSemanticConfig.getSemanticUrl(); + String fetchModelSchemaPath = defaultSemanticConfig.getFetchModelSchemaPath(); + URI requestUrl = UriComponentsBuilder.fromHttpUrl(semanticUrl + fetchModelSchemaPath) + .build().encode().toUri(); ModelSchemaFilterReq filter = new ModelSchemaFilterReq(); filter.setModelIds(ids); ParameterizedTypeReference>> responseTypeRef = @@ -179,6 +185,39 @@ public class RemoteSemanticLayer extends BaseSemanticLayer { return JsonUtil.toList(JsonUtil.toString(domainDescListObject), ModelResp.class); } + @Override + public ExplainResp explain(ExplainSqlReq explainResp, User user) throws Exception { + DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class); + String semanticUrl = defaultSemanticConfig.getSemanticUrl(); + String explainPath = defaultSemanticConfig.getExplainPath(); + URL url = new URL(new URL(semanticUrl), explainPath); + return explain(url.toString(), JsonUtil.toString(explainResp)); + } + + public ExplainResp explain(String url, String jsonReq) { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + fillToken(headers); + URI requestUrl = UriComponentsBuilder.fromHttpUrl(url).build().encode().toUri(); + HttpEntity entity = new HttpEntity<>(jsonReq, headers); + log.info("url:{},explain:{}", url, entity.getBody()); + ResultData responseBody; + try { + RestTemplate restTemplate = ContextUtils.getBean(RestTemplate.class); + + ResponseEntity> responseEntity = restTemplate.exchange( + requestUrl, HttpMethod.POST, entity, explainTypeRef); + log.info("ApiResponse responseBody:{}", responseEntity); + responseBody = responseEntity.getBody(); + if (Objects.nonNull(responseBody.getData())) { + return responseBody.getData(); + } + return null; + } catch (Exception e) { + throw new RuntimeException("explain interface error,url:" + url, e); + } + } + public Object fetchHttpResult(String url, String bodyJson, HttpMethod httpMethod) { HttpHeaders headers = new HttpHeaders(); headers.setContentType(MediaType.APPLICATION_JSON); @@ -219,7 +258,7 @@ public class RemoteSemanticLayer extends BaseSemanticLayer { } @Override - public PageInfo getMetricPage(PageMetricReq pageMetricCmd) { + public PageInfo getMetricPage(PageMetricReq pageMetricCmd, User user) { String body = JsonUtil.toString(pageMetricCmd); DefaultSemanticConfig defaultSemanticConfig = ContextUtils.getBean(DefaultSemanticConfig.class); log.info("url:{}", defaultSemanticConfig.getSemanticUrl() + defaultSemanticConfig.getFetchMetricPagePath()); diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/service/SchemaService.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/service/SchemaService.java index 57256c79b..6bc575e62 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/service/SchemaService.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/service/SchemaService.java @@ -18,7 +18,7 @@ public class SchemaService { public static final String ALL_CACHE = "all"; - private static final Integer META_CACHE_TIME = 5; + private static final Integer META_CACHE_TIME = 2; private SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer(); private LoadingCache cache = CacheBuilder.newBuilder() diff --git a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java index 65d870252..e17430eda 100644 --- a/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java +++ b/common/src/test/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserSelectHelperTest.java @@ -75,6 +75,13 @@ class SqlParserSelectHelperTest { + "user_id like '%alice%' AND publish_date > 10000 ORDER BY pv DESC LIMIT 1"); System.out.println(filterExpression); + + filterExpression = SqlParserSelectHelper.getFilterExpression( + "SELECT department, pv FROM s2 WHERE " + + "user_id like '%alice%' AND publish_date > 10000 " + + "group by department having sum(pv) > 2000 ORDER BY pv DESC LIMIT 1"); + + System.out.println(filterExpression); } diff --git a/docs/images/text2sql_config.png b/docs/images/text2sql_config.png index af552aca3..d9f641438 100644 Binary files a/docs/images/text2sql_config.png and b/docs/images/text2sql_config.png differ diff --git a/docs/userguides/text2sql_cn.md b/docs/userguides/text2sql_cn.md index 71c1efacc..eb207271d 100644 --- a/docs/userguides/text2sql_cn.md +++ b/docs/userguides/text2sql_cn.md @@ -5,21 +5,25 @@ text2sql的功能实现,高度依赖对LLM的应用。通过LLM生成SQL的过 ### **配置方式** 1. 样本池的配置。 - - supersonic/chat/core/src/main/python/llm/few_shot_example/sql_exampler.py为样本池配置文件。用户可以以已有的样本作为参考,配置更贴近自身业务需求的样本,用于更好的引导LLM生成SQL。 + - supersonic/chat/core/src/main/python/few_shot_example/sql_exampler.py 为样本池配置文件。用户可以以已有的样本作为参考,配置更贴近自身业务需求的样本,用于更好的引导LLM生成SQL。 2. 样本数量的配置。 - - 在supersonic/chat/core/src/main/python/llm/run_config.py 中通过 TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM 变量进行配置。 + - 在 supersonic/chat/core/src/main/python/run_config.py 中通过 TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM 变量进行配置。 - 默认值为15,为项目在内部实践后较优的经验值。样本少太少,对导致LLM在生成SQL的过程中缺少引导和示范,生成的SQL会更不稳定;样本太多,会增加生成SQL需要的时间和LLM的token消耗(或超过LLM的token上限)。 - -
+3. SQL生成方式的配置 + - 在 supersonic/chat/core/src/main/python/run_config.py 中通过 TEXT2DSL_IS_SHORTCUT 变量进行配置。 + - 默认值为False;当为False时,会调用2次LLM生成SQL;当为True时,会只调用1次LLM生成SQL。相较于2次LLM调用生成的SQL,耗时会减少30-40%,token的消耗量会减少30%左右,但生成的SQL正确率会有所下降。 +
-

图1-1 样本数量的配置文件

+

图1-1 配置文件

-3. 运行中更新配置的脚本。 - - 如果在启动项目后,用户需要对text2sql功能的相关配置进行调试,可以在修改相关配置文件后,通过脚本 supersonic/chat/core/src/main/python/bin/text2sql_resetting.sh 在项目运行中让配置生效。 - +### **运行中更新配置的脚本** +1. 如果在启动项目后,用户需要对text2sql功能的相关配置进行调试,可以在修改相关配置文件后,通过以下2种方式让配置在项目运行中让配置生效。 + - 执行 supersonic-daemon.sh reload llmparser + - 执行 python examples_reload_run.py ### **FAQ** 1. 生成一个SQL需要消耗的的LLM token数量太多了,按照openAI对token的收费标准,生成一个SQL太贵了,可以少用一些token吗? - - 可以。 用户可以根据自身需求,如配置方式1.中所示,修改样本池中的样本,选用一些更加简短的样本。如配置方式2.中所示,减少使用的样本数量。 + - 可以。 用户可以根据自身需求,如配置方式1.中所示,修改样本池中的样本,选用一些更加简短的样本。如配置方式2.中所示,减少使用的样本数量。配置方式3.中所示,只调用1次LLM生成SQL。 - 需要注意,样本和样本数量的选择对生成SQL的质量有很大的影响。过于激进的降低输入的token数量可能会降低生成SQL的质量。需要用户根据自身业务特点实测后进行平衡。 diff --git a/launchers/chat/src/main/resources/db/chat-data-h2.sql b/launchers/chat/src/main/resources/db/chat-data-h2.sql index a5207a41c..47989554b 100644 --- a/launchers/chat/src/main/resources/db/chat-data-h2.sql +++ b/launchers/chat/src/main/resources/db/chat-data-h2.sql @@ -1,4 +1,4 @@ -insert into s2_user (id, `name`, password, display_name, email) values (1, 'admin','admin','admin','admin@xx.com'); +insert into s2_user (id, `name`, password, display_name, email, is_admin) values (1, 'admin','admin','admin','admin@xx.com', 1); insert into s2_user (id, `name`, password, display_name, email) values (2, 'jack','123456','jack','jack@xx.com'); insert into s2_user (id, `name`, password, display_name, email) values (3, 'tom','123456','tom','tom@xx.com'); insert into s2_user (id, `name`, password, display_name, email) values (4, 'lucy','123456','lucy','lucy@xx.com'); diff --git a/launchers/chat/src/main/resources/db/chat-schema-h2.sql b/launchers/chat/src/main/resources/db/chat-schema-h2.sql index 8bcad7b6d..20e5c3bab 100644 --- a/launchers/chat/src/main/resources/db/chat-schema-h2.sql +++ b/launchers/chat/src/main/resources/db/chat-schema-h2.sql @@ -1,3 +1,4 @@ +-- chat tables CREATE TABLE IF NOT EXISTS `s2_chat_context` ( `chat_id` BIGINT NOT NULL , -- context chat id @@ -7,7 +8,7 @@ CREATE TABLE IF NOT EXISTS `s2_chat_context` `semantic_parse` LONGVARCHAR DEFAULT NULL , -- parse data `ext_data` LONGVARCHAR DEFAULT NULL , -- extend data PRIMARY KEY (`chat_id`) -); + ); CREATE TABLE IF NOT EXISTS `s2_chat` ( @@ -21,7 +22,7 @@ CREATE TABLE IF NOT EXISTS `s2_chat` `is_delete` INT DEFAULT '0' COMMENT 'is deleted', `is_top` INT DEFAULT '0' COMMENT 'is top', PRIMARY KEY (`chat_id`) -) ; + ) ; CREATE TABLE `s2_chat_query` @@ -64,72 +65,21 @@ CREATE TABLE `s2_chat_statistics` ); CREATE TABLE IF NOT EXISTS `s2_chat_config` ( - `id` INT NOT NULL AUTO_INCREMENT, - `model_id` INT DEFAULT NULL , - `chat_detail_config` varchar(655) , + `id` INT NOT NULL AUTO_INCREMENT, + `model_id` INT DEFAULT NULL , + `chat_detail_config` varchar(655) , `chat_agg_config` varchar(655) , - `recommended_questions` varchar(1500) , + `recommended_questions` varchar(1500) , `created_at` TIMESTAMP NOT NULL , `updated_at` TIMESTAMP NOT NULL , `created_by` varchar(100) NOT NULL , `updated_by` varchar(100) NOT NULL , `status` INT NOT NULL DEFAULT '0' , -- domain extension information status : 0 is normal, 1 is off the shelf, 2 is deleted PRIMARY KEY (`id`) -) ; - - --- CREATE TABLE IF NOT EXISTS `s2_chat_config` ( --- `id` INT NOT NULL AUTO_INCREMENT, --- `domain_id` INT DEFAULT NULL , --- `default_metrics` varchar(655) DEFAULT NULL, --- `visibility` varchar(655) , -- invisible dimension metric information --- `entity_info` varchar(655) , --- `dictionary_info` varchar(655) , -- dictionary-related dimension setting information --- `created_at` TIMESTAMP NOT NULL , --- `updated_at` TIMESTAMP NOT NULL , --- `created_by` varchar(100) NOT NULL , --- `updated_by` varchar(100) NOT NULL , --- `status` INT NOT NULL DEFAULT '0' , -- domain extension information status : 0 is normal, 1 is off the shelf, 2 is deleted --- PRIMARY KEY (`id`) --- ) ; + ) ; COMMENT ON TABLE s2_chat_config IS 'chat config information table '; - - -CREATE TABLE IF NOT EXISTS `s2_dictionary` ( - `id` INT NOT NULL AUTO_INCREMENT, - `domain_id` INT NOT NULL , - `dim_value_infos` LONGVARCHAR , -- dimension value setting information - `created_at` TIMESTAMP NOT NULL , - `updated_at` TIMESTAMP NOT NULL , - `created_by` varchar(100) NOT NULL , - `updated_by` varchar(100) DEFAULT NULL , - `status` INT NOT NULL DEFAULT '0' , -- domain extension information status : 0 is normal, 1 is off the shelf, 2 is deleted - PRIMARY KEY (`id`), - UNIQUE (domain_id) - ); -COMMENT ON TABLE s2_dictionary IS 'dictionary configuration information table'; - - -CREATE TABLE IF NOT EXISTS `s2_dictionary_task` ( - `id` INT NOT NULL AUTO_INCREMENT, - `name` varchar(255) NOT NULL , -- task name - `description` varchar(255) , - `command`LONGVARCHAR NOT NULL , -- task Request Parameters - `command_md5` varchar(255) NOT NULL , -- task Request Parameters md5 - `dimension_ids` varchar(500) , - `status` INT NOT NULL , -- the final status of the task - `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP , - `created_by` varchar(100) NOT NULL , - `progress` DOUBLE default 0.00 , -- task real-time progress - `elapsed_ms` bigINT DEFAULT NULL , -- the task takes time in milliseconds - `message` LONGVARCHAR , -- remark related information - PRIMARY KEY (`id`) -); -COMMENT ON TABLE s2_dictionary_task IS 'dictionary task information table'; - - create table s2_user ( id INT AUTO_INCREMENT, @@ -137,10 +87,33 @@ create table s2_user display_name varchar(100) null, password varchar(100) null, email varchar(100) null, + is_admin INT null, PRIMARY KEY (`id`) ); COMMENT ON TABLE s2_user IS 'user information table'; + +CREATE TABLE IF NOT EXISTS `s2_semantic_pasre_info` ( + `id` INT NOT NULL AUTO_INCREMENT, + `trace_id` varchar(200) NOT NULL , + `model_id` INT NOT NULL , + `dimensions`LONGVARCHAR , + `metrics`LONGVARCHAR , + `orders`LONGVARCHAR , + `filters`LONGVARCHAR , + `date_info`LONGVARCHAR , + `limit` INT NOT NULL , + `native_query` TINYINT NOT NULL DEFAULT '0' , + `sql`LONGVARCHAR , + `created_at` TIMESTAMP NOT NULL , + `created_by` varchar(100) NOT NULL , + `status` INT NOT NULL , + `elapsed_ms` bigINT DEFAULT NULL , + PRIMARY KEY (`id`) + ); +COMMENT ON TABLE s2_semantic_pasre_info IS 'semantic layer sql parsing information table'; + + CREATE TABLE IF NOT EXISTS `s2_plugin` ( `id` INT AUTO_INCREMENT, @@ -157,5 +130,50 @@ CREATE TABLE IF NOT EXISTS `s2_plugin` `config` LONGVARCHAR NULL, `comment` LONGVARCHAR NULL, PRIMARY KEY (`id`) -); COMMENT ON TABLE s2_plugin IS 'plugin information table'; + ); COMMENT ON TABLE s2_plugin IS 'plugin information table'; + +CREATE TABLE IF NOT EXISTS s2_agent +( + id int AUTO_INCREMENT, + name varchar(100) null, + description varchar(500) null, + status int null, + examples varchar(500) null, + config varchar(2000) null, + created_by varchar(100) null, + created_at TIMESTAMP null, + updated_by varchar(100) null, + updated_at TIMESTAMP null, + enable_search int null, + PRIMARY KEY (`id`) + ); COMMENT ON TABLE s2_agent IS 'agent information table'; + + +-------demo for semantic and chat +CREATE TABLE IF NOT EXISTS `s2_user_department` ( + `user_name` varchar(200) NOT NULL, + `department` varchar(200) NOT NULL -- department of user + ); +COMMENT ON TABLE s2_user_department IS 'user_department_info'; + + +CREATE TABLE IF NOT EXISTS `s2_dictionary_task` ( + `id` INT NOT NULL AUTO_INCREMENT, + `name` varchar(255) NOT NULL , -- task name + `description` varchar(255) , + `command`LONGVARCHAR NOT NULL , -- task Request Parameters + `command_md5` varchar(255) NOT NULL , -- task Request Parameters md5 + `status` INT NOT NULL , -- the final status of the task + `dimension_ids` varchar(500) NULL , + `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP , + `created_by` varchar(100) NOT NULL , + `progress` DOUBLE default 0.00 , -- task real-time progress + `elapsed_ms` bigINT DEFAULT NULL , -- the task takes time in milliseconds + `message` LONGVARCHAR , -- remark related information + PRIMARY KEY (`id`) + ); +COMMENT ON TABLE s2_dictionary_task IS 'dictionary task information table'; + + + diff --git a/launchers/semantic/src/main/resources/db/semantic-data-h2.sql b/launchers/semantic/src/main/resources/db/semantic-data-h2.sql index b9904da13..c11f8e64d 100644 --- a/launchers/semantic/src/main/resources/db/semantic-data-h2.sql +++ b/launchers/semantic/src/main/resources/db/semantic-data-h2.sql @@ -36,7 +36,7 @@ insert into s2_auth_groups (group_id, config) values (2, '{"domainId":"1","name":"tom_sales_permission","groupId":2,"authRules":[{"metrics":["stay_hours"],"dimensions":["page"]}],"dimensionFilters":["department in (''sales'')"],"dimensionFilterDescription":"开通 tom sales部门权限", "authorizedUsers":["tom"],"authorizedDepartmentIds":[]}'); -insert into s2_user (id, `name`, password, display_name, email) values (1, 'admin','admin','admin','admin@xx.com'); +insert into s2_user (id, `name`, password, display_name, email, is_admin) values (1, 'admin','admin','admin','admin@xx.com', 1); insert into s2_user (id, `name`, password, display_name, email) values (2, 'jack','123456','jack','jack@xx.com'); insert into s2_user (id, `name`, password, display_name, email) values (3, 'tom','123456','tom','tom@xx.com'); insert into s2_user (id, `name`, password, display_name, email) values (4, 'lucy','123456','lucy','lucy@xx.com'); diff --git a/launchers/semantic/src/main/resources/db/semantic-schema-h2.sql b/launchers/semantic/src/main/resources/db/semantic-schema-h2.sql index e27fcd725..2c846cead 100644 --- a/launchers/semantic/src/main/resources/db/semantic-schema-h2.sql +++ b/launchers/semantic/src/main/resources/db/semantic-schema-h2.sql @@ -80,6 +80,7 @@ create table s2_user display_name varchar(100) null, password varchar(100) null, email varchar(100) null, + is_admin INT null, PRIMARY KEY (`id`) ); COMMENT ON TABLE s2_user IS 'user information table'; @@ -108,6 +109,7 @@ CREATE TABLE IF NOT EXISTS `s2_metric` ( `data_format_type` varchar(50) DEFAULT NULL , `data_format` varchar(500) DEFAULT NULL, `alias` varchar(500) DEFAULT NULL, + `tags` varchar(500) DEFAULT NULL, PRIMARY KEY (`id`) ); COMMENT ON TABLE s2_metric IS 'metric information table'; diff --git a/launchers/standalone/src/main/resources/db/data-h2.sql b/launchers/standalone/src/main/resources/db/data-h2.sql index 0ba57383b..68709c0c2 100644 --- a/launchers/standalone/src/main/resources/db/data-h2.sql +++ b/launchers/standalone/src/main/resources/db/data-h2.sql @@ -1,8 +1,8 @@ -- sample user -insert into s2_user (id, `name`, password, display_name, email) values (1, 'admin','admin','admin','admin@xx.com'); +insert into s2_user (id, `name`, password, display_name, email, is_admin) values (1, 'admin','admin','admin','admin@xx.com', 1); insert into s2_user (id, `name`, password, display_name, email) values (2, 'jack','123456','jack','jack@xx.com'); insert into s2_user (id, `name`, password, display_name, email) values (3, 'tom','123456','tom','tom@xx.com'); -insert into s2_user (id, `name`, password, display_name, email) values (4, 'lucy','123456','lucy','lucy@xx.com'); +insert into s2_user (id, `name`, password, display_name, email, is_admin) values (4, 'lucy','123456','lucy','lucy@xx.com', 1); insert into s2_user (id, `name`, password, display_name, email) values (5, 'alice','123456','alice','alice@xx.com'); -- sample models diff --git a/launchers/standalone/src/main/resources/db/schema-h2.sql b/launchers/standalone/src/main/resources/db/schema-h2.sql index b8d090f07..1370bc82d 100644 --- a/launchers/standalone/src/main/resources/db/schema-h2.sql +++ b/launchers/standalone/src/main/resources/db/schema-h2.sql @@ -87,6 +87,7 @@ create table s2_user display_name varchar(100) null, password varchar(100) null, email varchar(100) null, + is_admin INT null, PRIMARY KEY (`id`) ); COMMENT ON TABLE s2_user IS 'user information table'; @@ -190,6 +191,7 @@ CREATE TABLE IF NOT EXISTS `s2_metric` ( `data_format_type` varchar(50) DEFAULT NULL , `data_format` varchar(500) DEFAULT NULL, `alias` varchar(500) DEFAULT NULL, + `tags` varchar(500) DEFAULT NULL, PRIMARY KEY (`id`) ); COMMENT ON TABLE s2_metric IS 'metric information table'; diff --git a/launchers/standalone/src/main/resources/db/schema-mysql.sql b/launchers/standalone/src/main/resources/db/schema-mysql.sql index da3579178..266c67ff6 100644 --- a/launchers/standalone/src/main/resources/db/schema-mysql.sql +++ b/launchers/standalone/src/main/resources/db/schema-mysql.sql @@ -254,6 +254,7 @@ CREATE TABLE `s2_metric` ( `data_format_type` varchar(50) DEFAULT NULL COMMENT '数值类型', `data_format` varchar(500) DEFAULT NULL COMMENT '数值类型参数', `alias` varchar(500) CHARACTER SET utf8 COLLATE utf8_unicode_ci DEFAULT NULL, + `tags` varchar(500) CHARACTER SET utf8 COLLATE utf8_unicode_ci DEFAULT NULL, PRIMARY KEY (`id`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8 COMMENT='指标表'; @@ -368,7 +369,8 @@ create table s2_user display_name varchar(100) null, password varchar(100) null, email varchar(100) null, + is_admin int(11) null, PRIMARY KEY (`id`) ); -insert into s2_user (id, `name`, password, display_name, email) values (1, 'admin','admin','admin','admin@xx.com'); +insert into s2_user (id, `name`, password, display_name, email, is_admin) values (1, 'admin','admin','admin','admin@xx.com', 1); diff --git a/launchers/standalone/src/main/resources/db/sql-update.sql b/launchers/standalone/src/main/resources/db/sql-update.sql index b56df3579..7799b3656 100644 --- a/launchers/standalone/src/main/resources/db/sql-update.sql +++ b/launchers/standalone/src/main/resources/db/sql-update.sql @@ -48,5 +48,10 @@ alter table s2_database drop column domain_id; alter table s2_chat add column agent_id int after chat_id; --20230907 +ALTER TABLE s2_model add alias varchar(200) default null after domain_id; -ALTER TABLE s2_model add alias varchar(200) default null after domain_id; \ No newline at end of file +--20230919 +alter table s2_metric add tags varchar(500) null; + +--20230920 +alter table s2_user add is_admin int null; \ No newline at end of file diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java b/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java index 0bb34a21c..50938af67 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/util/DataUtils.java @@ -23,7 +23,7 @@ import static java.time.LocalDate.now; public class DataUtils { - private static final User user_test = new User(1L, "admin", "admin", "admin@email"); + private static final User user_test = User.getFakeUser(); public static User getUser() { return user_test; diff --git a/launchers/standalone/src/test/resources/db/data-h2.sql b/launchers/standalone/src/test/resources/db/data-h2.sql index c2ee76401..10f6a3ef5 100644 --- a/launchers/standalone/src/test/resources/db/data-h2.sql +++ b/launchers/standalone/src/test/resources/db/data-h2.sql @@ -1,5 +1,5 @@ -- sample user -insert into s2_user (id, `name`, password, display_name, email) values (1, 'admin','admin','admin','admin@xx.com'); +insert into s2_user (id, `name`, password, display_name, email, is_admin) values (1, 'admin','admin','admin','admin@xx.com', 1); insert into s2_user (id, `name`, password, display_name, email) values (2, 'jack','123456','jack','jack@xx.com'); insert into s2_user (id, `name`, password, display_name, email) values (3, 'tom','123456','tom','tom@xx.com'); insert into s2_user (id, `name`, password, display_name, email) values (4, 'lucy','123456','lucy','lucy@xx.com'); diff --git a/launchers/standalone/src/test/resources/db/schema-h2.sql b/launchers/standalone/src/test/resources/db/schema-h2.sql index 64b84e38d..33429ca0b 100644 --- a/launchers/standalone/src/test/resources/db/schema-h2.sql +++ b/launchers/standalone/src/test/resources/db/schema-h2.sql @@ -102,6 +102,7 @@ create table s2_user display_name varchar(100) null, password varchar(100) null, email varchar(100) null, + is_admin INT null, PRIMARY KEY (`id`) ); COMMENT ON TABLE s2_user IS 'user information table'; @@ -205,6 +206,7 @@ CREATE TABLE IF NOT EXISTS `s2_metric` ( `data_format_type` varchar(50) DEFAULT NULL , `data_format` varchar(500) DEFAULT NULL, `alias` varchar(500) DEFAULT NULL, + `tags` varchar(500) DEFAULT NULL, PRIMARY KEY (`id`) ); COMMENT ON TABLE s2_metric IS 'metric information table'; diff --git a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/pojo/QueryStat.java b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/pojo/QueryStat.java index f08b126f9..b9819ac69 100644 --- a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/pojo/QueryStat.java +++ b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/pojo/QueryStat.java @@ -79,7 +79,7 @@ public class QueryStat { return this; } - public QueryStat setClassId(Long modelId) { + public QueryStat setModelId(Long modelId) { this.modelId = modelId; return this; } diff --git a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/MetricBaseReq.java b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/MetricBaseReq.java index 1540baeeb..7481d997d 100644 --- a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/MetricBaseReq.java +++ b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/MetricBaseReq.java @@ -4,6 +4,7 @@ package com.tencent.supersonic.semantic.api.model.request; import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem; import com.tencent.supersonic.common.pojo.DataFormat; import lombok.Data; +import java.util.List; @Data @@ -17,4 +18,6 @@ public class MetricBaseReq extends SchemaItem { private DataFormat dataFormat; + private List tags; + } diff --git a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/PageMetricReq.java b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/PageMetricReq.java index aad02b810..40edea0c2 100644 --- a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/PageMetricReq.java +++ b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/PageMetricReq.java @@ -9,6 +9,4 @@ public class PageMetricReq extends PageSchemaItemReq { private String type; - private String key; - } diff --git a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/PageSchemaItemReq.java b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/PageSchemaItemReq.java index b2580779d..94e26026d 100644 --- a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/PageSchemaItemReq.java +++ b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/request/PageSchemaItemReq.java @@ -16,4 +16,5 @@ public class PageSchemaItemReq extends PageBaseReq { private List modelIds = Lists.newArrayList(); private Integer sensitiveLevel; private Integer status; + private String key; } diff --git a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/response/ExplainResp.java b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/response/ExplainResp.java new file mode 100644 index 000000000..96f5fa8e6 --- /dev/null +++ b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/response/ExplainResp.java @@ -0,0 +1,19 @@ +package com.tencent.supersonic.semantic.api.model.response; + +import java.io.Serializable; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.ToString; + +@Data +@ToString +@Builder +@AllArgsConstructor +@NoArgsConstructor +public class ExplainResp implements Serializable { + + private String sql; + +} diff --git a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/response/MetricResp.java b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/response/MetricResp.java index 9afbe3368..0a2aecd24 100644 --- a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/response/MetricResp.java +++ b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/response/MetricResp.java @@ -1,11 +1,15 @@ package com.tencent.supersonic.semantic.api.model.response; +import com.google.common.collect.Lists; import com.tencent.supersonic.common.pojo.DataFormat; import com.tencent.supersonic.semantic.api.model.pojo.MetricTypeParams; import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem; import lombok.Data; import lombok.ToString; +import org.apache.commons.lang3.StringUtils; +import java.util.Arrays; +import java.util.List; @Data @@ -14,6 +18,8 @@ public class MetricResp extends SchemaItem { private Long modelId; + private Long domainId; + private String modelName; //ATOMIC DERIVED @@ -27,5 +33,15 @@ public class MetricResp extends SchemaItem { private String alias; + private List tags; + private boolean hasAdminRes = false; + + public void setTag(String tag) { + if (StringUtils.isBlank(tag)) { + tags = Lists.newArrayList(); + } else { + tags = Arrays.asList(tag.split(",")); + } + } } diff --git a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/query/request/ExplainSqlReq.java b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/query/request/ExplainSqlReq.java new file mode 100644 index 000000000..eabf702fa --- /dev/null +++ b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/query/request/ExplainSqlReq.java @@ -0,0 +1,20 @@ +package com.tencent.supersonic.semantic.api.query.request; + +import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.ToString; + +@Data +@ToString +@Builder +@AllArgsConstructor +@NoArgsConstructor +public class ExplainSqlReq { + + private QueryTypeEnum queryTypeEnum; + + private T queryReq; +} diff --git a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/query/request/ItemUseReq.java b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/query/request/ItemUseReq.java index d1af76a40..f98a808dc 100644 --- a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/query/request/ItemUseReq.java +++ b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/query/request/ItemUseReq.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.semantic.api.query.request; +import java.util.List; import lombok.Data; import lombok.NoArgsConstructor; import lombok.ToString; @@ -11,6 +12,7 @@ public class ItemUseReq { private String startTime; private Long modelId; + private List modelIds; private Boolean cacheEnable = true; private String metric; @@ -18,4 +20,8 @@ public class ItemUseReq { this.startTime = startTime; this.modelId = modelId; } + public ItemUseReq(String startTime, List modelIds) { + this.startTime = startTime; + this.modelIds = modelIds; + } } diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DatabaseServiceImpl.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DatabaseServiceImpl.java index 6c4e731a3..e1a23c908 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DatabaseServiceImpl.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DatabaseServiceImpl.java @@ -72,7 +72,8 @@ public class DatabaseServiceImpl implements DatabaseService { private void fillPermission(List databaseResps, User user) { databaseResps.forEach(databaseResp -> { if (databaseResp.getAdmins().contains(user.getName()) - || user.getName().equalsIgnoreCase(databaseResp.getCreatedBy())) { + || user.getName().equalsIgnoreCase(databaseResp.getCreatedBy()) + || user.isSuperAdmin()) { databaseResp.setHasPermission(true); databaseResp.setHasEditPermission(true); databaseResp.setHasUsePermission(true); @@ -111,7 +112,8 @@ public class DatabaseServiceImpl implements DatabaseService { List viewers = databaseResp.getViewers(); if (!admins.contains(user.getName()) && !viewers.contains(user.getName()) - && !databaseResp.getCreatedBy().equalsIgnoreCase(user.getName())) { + && !databaseResp.getCreatedBy().equalsIgnoreCase(user.getName()) + && !user.isSuperAdmin()) { String message = String.format("您暂无当前数据库%s权限, 请联系数据库管理员%s开通", databaseResp.getName(), String.join(",", admins)); diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DomainServiceImpl.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DomainServiceImpl.java index fd51d9496..ebc629b8d 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DomainServiceImpl.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/DomainServiceImpl.java @@ -96,12 +96,12 @@ public class DomainServiceImpl implements DomainService { @Override public List getDomainListWithAdminAuth(User user) { - Set domainWithAuthAll = getDomainAuthSet(user.getName(), AuthType.ADMIN); + Set domainWithAuthAll = getDomainAuthSet(user, AuthType.ADMIN); if (!CollectionUtils.isEmpty(domainWithAuthAll)) { List domainIds = domainWithAuthAll.stream().map(DomainResp::getId).collect(Collectors.toList()); domainWithAuthAll.addAll(getParentDomain(domainIds)); } - List modelResps = modelService.getModelAuthList(user.getName(), AuthType.ADMIN); + List modelResps = modelService.getModelAuthList(user, AuthType.ADMIN); if (!CollectionUtils.isEmpty(modelResps)) { List domainIds = modelResps.stream().map(ModelResp::getDomainId).collect(Collectors.toList()); domainWithAuthAll.addAll(getParentDomain(domainIds)); @@ -111,18 +111,18 @@ public class DomainServiceImpl implements DomainService { } @Override - public Set getDomainAuthSet(String userName, AuthType authTypeEnum) { + public Set getDomainAuthSet(User user, AuthType authTypeEnum) { List domainResps = getDomainList(); - Set orgIds = userService.getUserAllOrgId(userName); + Set orgIds = userService.getUserAllOrgId(user.getName()); List domainWithAuth = Lists.newArrayList(); if (authTypeEnum.equals(AuthType.ADMIN)) { domainWithAuth = domainResps.stream() - .filter(domainResp -> checkAdminPermission(orgIds, userName, domainResp)) + .filter(domainResp -> checkAdminPermission(orgIds, user, domainResp)) .collect(Collectors.toList()); } if (authTypeEnum.equals(AuthType.VISIBLE)) { domainWithAuth = domainResps.stream() - .filter(domainResp -> checkViewerPermission(orgIds, userName, domainResp)) + .filter(domainResp -> checkViewerPermission(orgIds, user, domainResp)) .collect(Collectors.toList()); } List domainIds = domainWithAuth.stream().map(DomainResp::getId) @@ -240,11 +240,13 @@ public class DomainServiceImpl implements DomainService { } - private boolean checkAdminPermission(Set orgIds, String userName, DomainResp domainResp) { - + private boolean checkAdminPermission(Set orgIds, User user, DomainResp domainResp) { List admins = domainResp.getAdmins(); List adminOrgs = domainResp.getAdminOrgs(); - if (admins.contains(userName) || domainResp.getCreatedBy().equals(userName)) { + if (user.isSuperAdmin()) { + return true; + } + if (admins.contains(user.getName()) || domainResp.getCreatedBy().equals(user.getName())) { return true; } if (CollectionUtils.isEmpty(adminOrgs)) { @@ -258,12 +260,17 @@ public class DomainServiceImpl implements DomainService { return false; } - private boolean checkViewerPermission(Set orgIds, String userName, DomainResp domainDesc) { + private boolean checkViewerPermission(Set orgIds, User user, DomainResp domainDesc) { List admins = domainDesc.getAdmins(); List viewers = domainDesc.getViewers(); List adminOrgs = domainDesc.getAdminOrgs(); List viewOrgs = domainDesc.getViewOrgs(); - if (admins.contains(userName) || viewers.contains(userName) || domainDesc.getCreatedBy().equals(userName)) { + if (user.isSuperAdmin()) { + return true; + } + if (admins.contains(user.getName()) + || viewers.contains(user.getName()) + || domainDesc.getCreatedBy().equals(user.getName())) { return true; } if (CollectionUtils.isEmpty(adminOrgs) && CollectionUtils.isEmpty(viewOrgs)) { diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/MetricServiceImpl.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/MetricServiceImpl.java index 6545e1b70..327e9017d 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/MetricServiceImpl.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/MetricServiceImpl.java @@ -9,6 +9,7 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.common.pojo.DataAddEvent; import com.tencent.supersonic.common.pojo.DataDeleteEvent; import com.tencent.supersonic.common.pojo.DataUpdateEvent; +import com.tencent.supersonic.common.pojo.enums.AuthType; import com.tencent.supersonic.common.pojo.enums.DictWordType; import com.tencent.supersonic.common.util.ChatGptHelper; import com.tencent.supersonic.semantic.api.model.pojo.Measure; @@ -28,6 +29,7 @@ import com.tencent.supersonic.semantic.model.domain.utils.MetricConverter; import com.tencent.supersonic.semantic.model.domain.MetricService; import com.tencent.supersonic.semantic.model.domain.pojo.Metric; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; @@ -123,7 +125,7 @@ public class MetricServiceImpl implements MetricService { } @Override - public PageInfo queryMetric(PageMetricReq pageMetricReq) { + public PageInfo queryMetric(PageMetricReq pageMetricReq, User user) { MetricFilter metricFilter = new MetricFilter(); BeanUtils.copyProperties(pageMetricReq, metricFilter); Set domainResps = domainService.getDomainChildren(pageMetricReq.getDomainIds()); @@ -137,7 +139,9 @@ public class MetricServiceImpl implements MetricService { .doSelectPageInfo(() -> queryMetric(metricFilter)); PageInfo pageInfo = new PageInfo<>(); BeanUtils.copyProperties(metricDOPageInfo, pageInfo); - pageInfo.setList(convertList(metricDOPageInfo.getList())); + List metricResps = convertList(metricDOPageInfo.getList()); + fillAdminRes(metricResps, user); + pageInfo.setList(metricResps); return pageInfo; } @@ -145,6 +149,21 @@ public class MetricServiceImpl implements MetricService { return metricRepository.getMetric(metricFilter); } + + private void fillAdminRes(List metricResps, User user) { + List modelResps = modelService.getModelListWithAuth(user, null, AuthType.ADMIN); + if (CollectionUtils.isEmpty(modelResps)) { + return; + } + Set modelIdSet = modelResps.stream().map(ModelResp::getId).collect(Collectors.toSet()); + for (MetricResp metricResp : metricResps) { + if (modelIdSet.contains(metricResp.getModelId())) { + metricResp.setHasAdminRes(true); + } + } + + } + @Override public MetricResp getMetric(Long modelId, String bizName) { List metricDescs = getMetricByModelId(modelId); @@ -250,6 +269,16 @@ public class MetricServiceImpl implements MetricService { }); } + @Override + public Set getMetricTags() { + List metricResps = getMetrics(); + if (CollectionUtils.isEmpty(metricResps)) { + return new HashSet<>(); + } + return metricResps.stream().flatMap(metricResp -> + metricResp.getTags().stream()).collect(Collectors.toSet()); + } + private void saveMetricBatch(List metrics, User user) { if (CollectionUtils.isEmpty(metrics)) { @@ -293,7 +322,7 @@ public class MetricServiceImpl implements MetricService { Map modelMap = modelService.getModelMap(); if (!CollectionUtils.isEmpty(metricDOS)) { metricDescs = metricDOS.stream() - .map(metricDO -> MetricConverter.convert2MetricDesc(metricDO, modelMap)) + .map(metricDO -> MetricConverter.convert2MetricResp(metricDO, modelMap)) .collect(Collectors.toList()); } return metricDescs; diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/ModelServiceImpl.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/ModelServiceImpl.java index b10dd9704..5c5cf2bca 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/ModelServiceImpl.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/ModelServiceImpl.java @@ -4,44 +4,43 @@ import com.alibaba.fastjson.JSONObject; import com.google.common.collect.Lists; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.service.UserService; +import com.tencent.supersonic.common.pojo.enums.AuthType; import com.tencent.supersonic.common.util.BeanMapper; import com.tencent.supersonic.common.util.JsonUtil; -import com.tencent.supersonic.common.pojo.enums.AuthType; import com.tencent.supersonic.semantic.api.model.request.ModelReq; import com.tencent.supersonic.semantic.api.model.request.ModelSchemaFilterReq; import com.tencent.supersonic.semantic.api.model.response.DatabaseResp; -import com.tencent.supersonic.semantic.api.model.response.ModelResp; -import com.tencent.supersonic.semantic.api.model.response.DomainResp; -import com.tencent.supersonic.semantic.api.model.response.DimensionResp; -import com.tencent.supersonic.semantic.api.model.response.MetricResp; -import com.tencent.supersonic.semantic.api.model.response.DimSchemaResp; -import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp; -import com.tencent.supersonic.semantic.api.model.response.MetricSchemaResp; import com.tencent.supersonic.semantic.api.model.response.DatasourceResp; +import com.tencent.supersonic.semantic.api.model.response.DimSchemaResp; +import com.tencent.supersonic.semantic.api.model.response.DimensionResp; +import com.tencent.supersonic.semantic.api.model.response.DomainResp; +import com.tencent.supersonic.semantic.api.model.response.MetricResp; +import com.tencent.supersonic.semantic.api.model.response.MetricSchemaResp; +import com.tencent.supersonic.semantic.api.model.response.ModelResp; +import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp; import com.tencent.supersonic.semantic.model.domain.DatabaseService; -import com.tencent.supersonic.semantic.model.domain.ModelService; -import com.tencent.supersonic.semantic.model.domain.DomainService; -import com.tencent.supersonic.semantic.model.domain.DimensionService; -import com.tencent.supersonic.semantic.model.domain.MetricService; import com.tencent.supersonic.semantic.model.domain.DatasourceService; - +import com.tencent.supersonic.semantic.model.domain.DimensionService; +import com.tencent.supersonic.semantic.model.domain.DomainService; +import com.tencent.supersonic.semantic.model.domain.MetricService; +import com.tencent.supersonic.semantic.model.domain.ModelService; import com.tencent.supersonic.semantic.model.domain.dataobject.ModelDO; import com.tencent.supersonic.semantic.model.domain.pojo.Model; import com.tencent.supersonic.semantic.model.domain.repository.ModelRepository; import com.tencent.supersonic.semantic.model.domain.utils.ModelConvert; +import java.util.ArrayList; +import java.util.Date; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.BeanUtils; import org.springframework.context.annotation.Lazy; import org.springframework.stereotype.Service; import org.springframework.util.CollectionUtils; -import java.util.List; -import java.util.Objects; -import java.util.Date; -import java.util.Set; -import java.util.Map; -import java.util.HashSet; -import java.util.ArrayList; -import java.util.stream.Collectors; @Slf4j @Service @@ -97,10 +96,10 @@ public class ModelServiceImpl implements ModelService { } @Override - public List getModelListWithAuth(String userName, Long domainId, AuthType authType) { - List modelResps = getModelAuthList(userName, authType); + public List getModelListWithAuth(User user, Long domainId, AuthType authType) { + List modelResps = getModelAuthList(user, authType); Set modelRespSet = new HashSet<>(modelResps); - List modelRespsAuthInheritDomain = getModelRespAuthInheritDomain(userName, authType); + List modelRespsAuthInheritDomain = getModelRespAuthInheritDomain(user, authType); modelRespSet.addAll(modelRespsAuthInheritDomain); if (domainId != null && domainId > 0) { modelRespSet = modelRespSet.stream().filter(modelResp -> @@ -109,8 +108,8 @@ public class ModelServiceImpl implements ModelService { return fillMetricInfo(new ArrayList<>(modelRespSet)); } - public List getModelRespAuthInheritDomain(String userName, AuthType authType) { - Set domainResps = domainService.getDomainAuthSet(userName, authType); + public List getModelRespAuthInheritDomain(User user, AuthType authType) { + Set domainResps = domainService.getDomainAuthSet(user, authType); if (CollectionUtils.isEmpty(domainResps)) { return Lists.newArrayList(); } @@ -121,18 +120,18 @@ public class ModelServiceImpl implements ModelService { } @Override - public List getModelAuthList(String userName, AuthType authTypeEnum) { + public List getModelAuthList(User user, AuthType authTypeEnum) { List modelResps = getModelList(); - Set orgIds = userService.getUserAllOrgId(userName); + Set orgIds = userService.getUserAllOrgId(user.getName()); List modelWithAuth = Lists.newArrayList(); if (authTypeEnum.equals(AuthType.ADMIN)) { modelWithAuth = modelResps.stream() - .filter(modelResp -> checkAdminPermission(orgIds, userName, modelResp)) + .filter(modelResp -> checkAdminPermission(orgIds, user, modelResp)) .collect(Collectors.toList()); } if (authTypeEnum.equals(AuthType.VISIBLE)) { modelWithAuth = modelResps.stream() - .filter(domainResp -> checkViewerPermission(orgIds, userName, domainResp)) + .filter(domainResp -> checkViewerPermission(orgIds, user, domainResp)) .collect(Collectors.toList()); } return modelWithAuth; @@ -325,9 +324,13 @@ public class ModelServiceImpl implements ModelService { return new ArrayList<>(getModelMap().keySet()); } - public static boolean checkAdminPermission(Set orgIds, String userName, ModelResp modelResp) { + public static boolean checkAdminPermission(Set orgIds, User user, ModelResp modelResp) { List admins = modelResp.getAdmins(); List adminOrgs = modelResp.getAdminOrgs(); + if (user.isSuperAdmin()) { + return true; + } + String userName = user.getName(); if (admins.contains(userName) || modelResp.getCreatedBy().equals(userName)) { return true; } @@ -342,14 +345,18 @@ public class ModelServiceImpl implements ModelService { return false; } - public static boolean checkViewerPermission(Set orgIds, String userName, ModelResp modelResp) { + public static boolean checkViewerPermission(Set orgIds, User user, ModelResp modelResp) { List admins = modelResp.getAdmins(); List viewers = modelResp.getViewers(); List adminOrgs = modelResp.getAdminOrgs(); List viewOrgs = modelResp.getViewOrgs(); + if (user.isSuperAdmin()) { + return true; + } if (modelResp.openToAll()) { return true; } + String userName = user.getName(); if (admins.contains(userName) || viewers.contains(userName) || modelResp.getCreatedBy().equals(userName)) { return true; } diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/DomainService.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/DomainService.java index 565167b8a..6a72d2adf 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/DomainService.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/DomainService.java @@ -30,7 +30,7 @@ public interface DomainService { List getDomainListWithAdminAuth(User user); - Set getDomainAuthSet(String userName, AuthType authTypeEnum); + Set getDomainAuthSet(User user, AuthType authTypeEnum); Set getDomainChildren(List domainId); diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/MetricService.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/MetricService.java index 1b28e8e64..969b935e5 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/MetricService.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/MetricService.java @@ -7,6 +7,7 @@ import com.tencent.supersonic.semantic.api.model.request.PageMetricReq; import com.tencent.supersonic.semantic.api.model.response.MetricResp; import java.util.List; +import java.util.Set; public interface MetricService { @@ -22,7 +23,7 @@ public interface MetricService { void createMetricBatch(List metricReqs, User user) throws Exception; - PageInfo queryMetric(PageMetricReq pageMetrricReq); + PageInfo queryMetric(PageMetricReq pageMetricReq, User user); MetricResp getMetric(Long modelId, String bizName); @@ -35,4 +36,6 @@ public interface MetricService { void deleteMetric(Long id) throws Exception; List mockAlias(MetricReq metricReq, String mockType, User user); + + Set getMetricTags(); } diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/ModelService.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/ModelService.java index 7e05fa38e..f4458d8ba 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/ModelService.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/ModelService.java @@ -13,9 +13,9 @@ import java.util.Map; public interface ModelService { - List getModelListWithAuth(String userName, Long domainId, AuthType authType); + List getModelListWithAuth(User user, Long domainId, AuthType authType); - List getModelAuthList(String userName, AuthType authTypeEnum); + List getModelAuthList(User user, AuthType authTypeEnum); List getModelByDomainIds(List domainIds); diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/dataobject/MetricDO.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/dataobject/MetricDO.java index ceb2fad56..098e63945 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/dataobject/MetricDO.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/dataobject/MetricDO.java @@ -3,116 +3,247 @@ package com.tencent.supersonic.semantic.model.domain.dataobject; import java.util.Date; public class MetricDO { - + /** + * + */ private Long id; + /** + * 主体域ID + */ private Long modelId; + /** + * 指标名称 + */ private String name; + /** + * 字段名称 + */ private String bizName; + /** + * 描述 + */ private String description; + /** + * 指标状态,0正常,1下架,2删除 + */ private Integer status; + /** + * 敏感级别 + */ private Integer sensitiveLevel; + /** + * 指标类型 proxy,expr + */ private String type; + /** + * 创建时间 + */ private Date createdAt; + /** + * 创建人 + */ private String createdBy; + /** + * 更新时间 + */ private Date updatedAt; + /** + * 更新人 + */ private String updatedBy; + /** + * 数值类型 + */ private String dataFormatType; + /** + * 数值类型参数 + */ private String dataFormat; + /** + * + */ private String alias; + /** + * + */ + private String tags; + + /** + * 类型参数 + */ private String typeParams; - + /** + * + * @return id + */ public Long getId() { return id; } + /** + * + * @param id + */ public void setId(Long id) { this.id = id; } + /** + * 主体域ID + * @return model_id 主体域ID + */ public Long getModelId() { return modelId; } + /** + * 主体域ID + * @param modelId 主体域ID + */ public void setModelId(Long modelId) { this.modelId = modelId; } + /** + * 指标名称 + * @return name 指标名称 + */ public String getName() { return name; } + /** + * 指标名称 + * @param name 指标名称 + */ public void setName(String name) { this.name = name == null ? null : name.trim(); } + /** + * 字段名称 + * @return biz_name 字段名称 + */ public String getBizName() { return bizName; } + /** + * 字段名称 + * @param bizName 字段名称 + */ public void setBizName(String bizName) { this.bizName = bizName == null ? null : bizName.trim(); } + /** + * 描述 + * @return description 描述 + */ public String getDescription() { return description; } + /** + * 描述 + * @param description 描述 + */ public void setDescription(String description) { this.description = description == null ? null : description.trim(); } + /** + * 指标状态,0正常,1下架,2删除 + * @return status 指标状态,0正常,1下架,2删除 + */ public Integer getStatus() { return status; } + /** + * 指标状态,0正常,1下架,2删除 + * @param status 指标状态,0正常,1下架,2删除 + */ public void setStatus(Integer status) { this.status = status; } + /** + * 敏感级别 + * @return sensitive_level 敏感级别 + */ public Integer getSensitiveLevel() { return sensitiveLevel; } + /** + * 敏感级别 + * @param sensitiveLevel 敏感级别 + */ public void setSensitiveLevel(Integer sensitiveLevel) { this.sensitiveLevel = sensitiveLevel; } + /** + * 指标类型 proxy,expr + * @return type 指标类型 proxy,expr + */ public String getType() { return type; } + /** + * 指标类型 proxy,expr + * @param type 指标类型 proxy,expr + */ public void setType(String type) { this.type = type == null ? null : type.trim(); } + /** + * 创建时间 + * @return created_at 创建时间 + */ public Date getCreatedAt() { return createdAt; } + /** + * 创建时间 + * @param createdAt 创建时间 + */ public void setCreatedAt(Date createdAt) { this.createdAt = createdAt; } + /** + * 创建人 + * @return created_by 创建人 + */ public String getCreatedBy() { return createdBy; } + /** + * 创建人 + * @param createdBy 创建人 + */ public void setCreatedBy(String createdBy) { this.createdBy = createdBy == null ? null : createdBy.trim(); } @@ -182,21 +313,37 @@ public class MetricDO { } /** - * - * @return alias + * + * @return alias */ public String getAlias() { return alias; } /** - * - * @param alias + * + * @param alias */ public void setAlias(String alias) { this.alias = alias == null ? null : alias.trim(); } + /** + * + * @return tags + */ + public String getTags() { + return tags; + } + + /** + * + * @param tags + */ + public void setTags(String tags) { + this.tags = tags == null ? null : tags.trim(); + } + /** * 类型参数 * @return type_params 类型参数 @@ -212,4 +359,4 @@ public class MetricDO { public void setTypeParams(String typeParams) { this.typeParams = typeParams == null ? null : typeParams.trim(); } -} +} \ No newline at end of file diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/dataobject/MetricDOExample.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/dataobject/MetricDOExample.java index a74ee01eb..d57855cfe 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/dataobject/MetricDOExample.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/dataobject/MetricDOExample.java @@ -31,6 +31,7 @@ public class MetricDOExample { protected Integer limitEnd; /** + * * @mbg.generated */ public MetricDOExample() { @@ -38,6 +39,7 @@ public class MetricDOExample { } /** + * * @mbg.generated */ public void setOrderByClause(String orderByClause) { @@ -45,6 +47,7 @@ public class MetricDOExample { } /** + * * @mbg.generated */ public String getOrderByClause() { @@ -52,6 +55,7 @@ public class MetricDOExample { } /** + * * @mbg.generated */ public void setDistinct(boolean distinct) { @@ -59,6 +63,7 @@ public class MetricDOExample { } /** + * * @mbg.generated */ public boolean isDistinct() { @@ -66,6 +71,7 @@ public class MetricDOExample { } /** + * * @mbg.generated */ public List getOredCriteria() { @@ -73,6 +79,7 @@ public class MetricDOExample { } /** + * * @mbg.generated */ public void or(Criteria criteria) { @@ -80,6 +87,7 @@ public class MetricDOExample { } /** + * * @mbg.generated */ public Criteria or() { @@ -89,6 +97,7 @@ public class MetricDOExample { } /** + * * @mbg.generated */ public Criteria createCriteria() { @@ -100,6 +109,7 @@ public class MetricDOExample { } /** + * * @mbg.generated */ protected Criteria createCriteriaInternal() { @@ -108,6 +118,7 @@ public class MetricDOExample { } /** + * * @mbg.generated */ public void clear() { @@ -117,13 +128,15 @@ public class MetricDOExample { } /** + * * @mbg.generated */ public void setLimitStart(Integer limitStart) { - this.limitStart = limitStart; + this.limitStart=limitStart; } /** + * * @mbg.generated */ public Integer getLimitStart() { @@ -131,13 +144,15 @@ public class MetricDOExample { } /** + * * @mbg.generated */ public void setLimitEnd(Integer limitEnd) { - this.limitEnd = limitEnd; + this.limitEnd=limitEnd; } /** + * * @mbg.generated */ public Integer getLimitEnd() { @@ -1177,6 +1192,76 @@ public class MetricDOExample { addCriterion("alias not between", value1, value2, "alias"); return (Criteria) this; } + + public Criteria andTagsIsNull() { + addCriterion("tags is null"); + return (Criteria) this; + } + + public Criteria andTagsIsNotNull() { + addCriterion("tags is not null"); + return (Criteria) this; + } + + public Criteria andTagsEqualTo(String value) { + addCriterion("tags =", value, "tags"); + return (Criteria) this; + } + + public Criteria andTagsNotEqualTo(String value) { + addCriterion("tags <>", value, "tags"); + return (Criteria) this; + } + + public Criteria andTagsGreaterThan(String value) { + addCriterion("tags >", value, "tags"); + return (Criteria) this; + } + + public Criteria andTagsGreaterThanOrEqualTo(String value) { + addCriterion("tags >=", value, "tags"); + return (Criteria) this; + } + + public Criteria andTagsLessThan(String value) { + addCriterion("tags <", value, "tags"); + return (Criteria) this; + } + + public Criteria andTagsLessThanOrEqualTo(String value) { + addCriterion("tags <=", value, "tags"); + return (Criteria) this; + } + + public Criteria andTagsLike(String value) { + addCriterion("tags like", value, "tags"); + return (Criteria) this; + } + + public Criteria andTagsNotLike(String value) { + addCriterion("tags not like", value, "tags"); + return (Criteria) this; + } + + public Criteria andTagsIn(List values) { + addCriterion("tags in", values, "tags"); + return (Criteria) this; + } + + public Criteria andTagsNotIn(List values) { + addCriterion("tags not in", values, "tags"); + return (Criteria) this; + } + + public Criteria andTagsBetween(String value1, String value2) { + addCriterion("tags between", value1, value2, "tags"); + return (Criteria) this; + } + + public Criteria andTagsNotBetween(String value1, String value2) { + addCriterion("tags not between", value1, value2, "tags"); + return (Criteria) this; + } } /** @@ -1209,6 +1294,38 @@ public class MetricDOExample { private String typeHandler; + public String getCondition() { + return condition; + } + + public Object getValue() { + return value; + } + + public Object getSecondValue() { + return secondValue; + } + + public boolean isNoValue() { + return noValue; + } + + public boolean isSingleValue() { + return singleValue; + } + + public boolean isBetweenValue() { + return betweenValue; + } + + public boolean isListValue() { + return listValue; + } + + public String getTypeHandler() { + return typeHandler; + } + protected Criterion(String condition) { super(); this.condition = condition; @@ -1244,37 +1361,5 @@ public class MetricDOExample { protected Criterion(String condition, Object value, Object secondValue) { this(condition, value, secondValue, null); } - - public String getCondition() { - return condition; - } - - public Object getValue() { - return value; - } - - public Object getSecondValue() { - return secondValue; - } - - public boolean isNoValue() { - return noValue; - } - - public boolean isSingleValue() { - return singleValue; - } - - public boolean isBetweenValue() { - return betweenValue; - } - - public boolean isListValue() { - return listValue; - } - - public String getTypeHandler() { - return typeHandler; - } } -} +} \ No newline at end of file diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/pojo/Metric.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/pojo/Metric.java index 771915de1..cfdc4917d 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/pojo/Metric.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/pojo/Metric.java @@ -1,10 +1,12 @@ package com.tencent.supersonic.semantic.model.domain.pojo; - import com.tencent.supersonic.common.pojo.DataFormat; import com.tencent.supersonic.semantic.api.model.pojo.MetricTypeParams; import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem; import lombok.Data; +import org.apache.commons.lang3.StringUtils; +import org.springframework.util.CollectionUtils; +import java.util.List; @Data public class Metric extends SchemaItem { @@ -23,4 +25,13 @@ public class Metric extends SchemaItem { private String alias; + private List tags; + + public String getTag() { + if (CollectionUtils.isEmpty(tags)) { + return ""; + } + return StringUtils.join(tags, ","); + } + } diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/utils/MetricConverter.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/utils/MetricConverter.java index 9b6513d3a..5573c7b87 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/utils/MetricConverter.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/utils/MetricConverter.java @@ -37,6 +37,7 @@ public class MetricConverter { if (metric.getDataFormat() != null) { metricDO.setDataFormat(JSONObject.toJSONString(metric.getDataFormat())); } + metricDO.setTags(metric.getTag()); return metricDO; } @@ -51,27 +52,23 @@ public class MetricConverter { BeanUtils.copyProperties(metric, metricDO); metricDO.setTypeParams(JSONObject.toJSONString(metric.getTypeParams())); metricDO.setDataFormat(JSONObject.toJSONString(metric.getDataFormat())); + metricDO.setTags(metric.getTag()); return metricDO; } - public static MetricResp convert2MetricDesc(MetricDO metricDO, Map modelMap) { - MetricResp metricDesc = new MetricResp(); - BeanUtils.copyProperties(metricDO, metricDesc); - metricDesc.setTypeParams(JSONObject.parseObject(metricDO.getTypeParams(), MetricTypeParams.class)); - metricDesc.setDataFormat(JSONObject.parseObject(metricDO.getDataFormat(), DataFormat.class)); + public static MetricResp convert2MetricResp(MetricDO metricDO, Map modelMap) { + MetricResp metricResp = new MetricResp(); + BeanUtils.copyProperties(metricDO, metricResp); + metricResp.setTypeParams(JSONObject.parseObject(metricDO.getTypeParams(), MetricTypeParams.class)); + metricResp.setDataFormat(JSONObject.parseObject(metricDO.getDataFormat(), DataFormat.class)); ModelResp modelResp = modelMap.get(metricDO.getModelId()); if (modelResp != null) { - metricDesc.setModelName(modelResp.getName()); + metricResp.setModelName(modelResp.getName()); + metricResp.setDomainId(modelResp.getDomainId()); } - return metricDesc; - } - - public static Metric convert2Metric(MetricDO metricDO) { - Metric metric = new Metric(); - BeanUtils.copyProperties(metricDO, metric); - metric.setTypeParams(JSONObject.parseObject(metricDO.getTypeParams(), MetricTypeParams.class)); - return metric; + metricResp.setTag(metricDO.getTags()); + return metricResp; } diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/MetricController.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/MetricController.java index 99a164695..4b60bfab2 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/MetricController.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/MetricController.java @@ -9,6 +9,7 @@ import com.tencent.supersonic.semantic.api.model.request.PageMetricReq; import com.tencent.supersonic.semantic.api.model.response.MetricResp; import com.tencent.supersonic.semantic.model.domain.MetricService; import java.util.List; +import java.util.Set; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -68,8 +69,11 @@ public class MetricController { @PostMapping("/queryMetric") - public PageInfo queryMetric(@RequestBody PageMetricReq pageMetrricReq) { - return metricService.queryMetric(pageMetrricReq); + public PageInfo queryMetric(@RequestBody PageMetricReq pageMetricReq, + HttpServletRequest request, + HttpServletResponse response) { + User user = UserHolder.findUser(request, response); + return metricService.queryMetric(pageMetricReq, user); } @GetMapping("getMetric/{modelId}/{bizName}") @@ -90,4 +94,9 @@ public class MetricController { } + @GetMapping("/getMetricTags") + public Set getMetricTags() { + return metricService.getMetricTags(); + } + } diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/ModelController.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/ModelController.java index 0eb0b5175..3c3ec3624 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/ModelController.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/rest/ModelController.java @@ -60,7 +60,7 @@ public class ModelController { HttpServletRequest request, HttpServletResponse response) { User user = UserHolder.findUser(request, response); - return modelService.getModelListWithAuth(user.getName(), domainId, AuthType.ADMIN); + return modelService.getModelListWithAuth(user, domainId, AuthType.ADMIN); } diff --git a/semantic/model/src/main/resources/mapper/MetricDOMapper.xml b/semantic/model/src/main/resources/mapper/MetricDOMapper.xml index 37f09c2b5..2d20f9ee8 100644 --- a/semantic/model/src/main/resources/mapper/MetricDOMapper.xml +++ b/semantic/model/src/main/resources/mapper/MetricDOMapper.xml @@ -17,6 +17,7 @@ + @@ -52,7 +53,7 @@ id, model_id, name, biz_name, description, status, sensitive_level, type, created_at, - created_by, updated_at, updated_by, data_format_type, data_format, alias + created_by, updated_at, updated_by, data_format_type, data_format, alias, tags type_params @@ -108,13 +109,13 @@ sensitive_level, type, created_at, created_by, updated_at, updated_by, data_format_type, data_format, alias, - type_params) + tags, type_params) values (#{id,jdbcType=BIGINT}, #{modelId,jdbcType=BIGINT}, #{name,jdbcType=VARCHAR}, #{bizName,jdbcType=VARCHAR}, #{description,jdbcType=VARCHAR}, #{status,jdbcType=INTEGER}, #{sensitiveLevel,jdbcType=INTEGER}, #{type,jdbcType=VARCHAR}, #{createdAt,jdbcType=TIMESTAMP}, #{createdBy,jdbcType=VARCHAR}, #{updatedAt,jdbcType=TIMESTAMP}, #{updatedBy,jdbcType=VARCHAR}, #{dataFormatType,jdbcType=VARCHAR}, #{dataFormat,jdbcType=VARCHAR}, #{alias,jdbcType=VARCHAR}, - #{typeParams,jdbcType=LONGVARCHAR}) + #{tags,jdbcType=VARCHAR}, #{typeParams,jdbcType=LONGVARCHAR}) insert into s2_metric @@ -164,6 +165,9 @@ alias, + + tags, + type_params, @@ -214,6 +218,9 @@ #{alias,jdbcType=VARCHAR}, + + #{tags,jdbcType=VARCHAR}, + #{typeParams,jdbcType=LONGVARCHAR}, @@ -270,6 +277,9 @@ alias = #{alias,jdbcType=VARCHAR}, + + tags = #{tags,jdbcType=VARCHAR}, + type_params = #{typeParams,jdbcType=LONGVARCHAR}, @@ -292,6 +302,7 @@ data_format_type = #{dataFormatType,jdbcType=VARCHAR}, data_format = #{dataFormat,jdbcType=VARCHAR}, alias = #{alias,jdbcType=VARCHAR}, + tags = #{tags,jdbcType=VARCHAR}, type_params = #{typeParams,jdbcType=LONGVARCHAR} where id = #{id,jdbcType=BIGINT} @@ -310,7 +321,8 @@ updated_by = #{updatedBy,jdbcType=VARCHAR}, data_format_type = #{dataFormatType,jdbcType=VARCHAR}, data_format = #{dataFormat,jdbcType=VARCHAR}, - alias = #{alias,jdbcType=VARCHAR} + alias = #{alias,jdbcType=VARCHAR}, + tags = #{tags,jdbcType=VARCHAR} where id = #{id,jdbcType=BIGINT} \ No newline at end of file diff --git a/semantic/model/src/main/resources/mapper/custom/MetricDOCustomMapper.xml b/semantic/model/src/main/resources/mapper/custom/MetricDOCustomMapper.xml index 67db20d9c..8546d856b 100644 --- a/semantic/model/src/main/resources/mapper/custom/MetricDOCustomMapper.xml +++ b/semantic/model/src/main/resources/mapper/custom/MetricDOCustomMapper.xml @@ -2,22 +2,26 @@ - - - - - - - - - - - + + + + + + + + + + + + + + + + + - - + + @@ -51,12 +55,11 @@ - id - , model_id, name, biz_name, description, type, created_at, created_by, updated_at, - updated_by + id, model_id, name, biz_name, description, status, sensitive_level, type, created_at, + created_by, updated_at, updated_by, data_format_type, data_format, alias, tags - typeParams + type_params @@ -108,7 +111,8 @@ and ( id like CONCAT('%',#{key , jdbcType=VARCHAR},'%') or name like CONCAT('%',#{key , jdbcType=VARCHAR},'%') or biz_name like CONCAT('%',#{key , jdbcType=VARCHAR},'%') or - description like CONCAT('%',#{key , jdbcType=VARCHAR},'%') ) + description like CONCAT('%',#{key , jdbcType=VARCHAR},'%') or + tags like CONCAT('%',#{key , jdbcType=VARCHAR},'%') ) and id like CONCAT('%',#{id , jdbcType=VARCHAR},'%') diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/rest/QueryController.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/rest/QueryController.java index ef77abbe5..eafe6c0d8 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/rest/QueryController.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/rest/QueryController.java @@ -2,18 +2,22 @@ package com.tencent.supersonic.semantic.query.rest; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.auth.api.authentication.utils.UserHolder; +import com.tencent.supersonic.common.util.JsonUtil; +import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum; +import com.tencent.supersonic.semantic.api.model.response.ExplainResp; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; import com.tencent.supersonic.semantic.api.model.response.SqlParserResp; +import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq; +import com.tencent.supersonic.semantic.api.query.request.ItemUseReq; +import com.tencent.supersonic.semantic.api.query.request.ParseSqlReq; import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq; import com.tencent.supersonic.semantic.api.query.request.QueryDslReq; -import com.tencent.supersonic.semantic.api.query.request.ParseSqlReq; -import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq; -import com.tencent.supersonic.semantic.api.query.request.ItemUseReq; +import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import com.tencent.supersonic.semantic.api.query.response.ItemUseResp; -import com.tencent.supersonic.semantic.query.service.SemanticQueryEngine; -import com.tencent.supersonic.semantic.query.service.QueryService; import com.tencent.supersonic.semantic.query.persistence.pojo.QueryStatement; +import com.tencent.supersonic.semantic.query.service.QueryService; +import com.tencent.supersonic.semantic.query.service.SemanticQueryEngine; import java.util.List; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -89,10 +93,36 @@ public class QueryController { @PostMapping("/queryDimValue") public QueryResultWithSchemaResp queryDimValue(@RequestBody QueryDimValueReq queryDimValueReq, - HttpServletRequest request, - HttpServletResponse response) { + HttpServletRequest request, + HttpServletResponse response) { User user = UserHolder.findUser(request, response); return queryService.queryDimValue(queryDimValueReq, user); } + @PostMapping("/explain") + public ExplainResp explain(@RequestBody ExplainSqlReq explainSqlReq, + HttpServletRequest request, + HttpServletResponse response) throws Exception { + + User user = UserHolder.findUser(request, response); + String queryReqJson = JsonUtil.toString(explainSqlReq.getQueryReq()); + QueryTypeEnum queryTypeEnum = explainSqlReq.getQueryTypeEnum(); + + if (QueryTypeEnum.SQL.equals(queryTypeEnum)) { + QueryDslReq queryDslReq = JsonUtil.toObject(queryReqJson, QueryDslReq.class); + ExplainSqlReq explainSqlReqNew = ExplainSqlReq.builder() + .queryReq(queryDslReq) + .queryTypeEnum(queryTypeEnum).build(); + return queryService.explain(explainSqlReqNew, user); + } + if (QueryTypeEnum.STRUCT.equals(queryTypeEnum)) { + QueryStructReq queryStructReq = JsonUtil.toObject(queryReqJson, QueryStructReq.class); + ExplainSqlReq explainSqlReqNew = ExplainSqlReq.builder() + .queryReq(queryStructReq) + .queryTypeEnum(queryTypeEnum).build(); + return queryService.explain(explainSqlReqNew, user); + } + return null; + } + } diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryService.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryService.java index 28b04ed95..8e8a6e4c8 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryService.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryService.java @@ -1,11 +1,13 @@ package com.tencent.supersonic.semantic.query.service; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.semantic.api.model.response.ExplainResp; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; +import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq; import com.tencent.supersonic.semantic.api.query.request.ItemUseReq; import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq; -import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq; import com.tencent.supersonic.semantic.api.query.request.QueryDslReq; +import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import com.tencent.supersonic.semantic.api.query.response.ItemUseResp; import java.util.List; @@ -25,4 +27,5 @@ public interface QueryService { List getStatInfo(ItemUseReq itemUseCommend); + ExplainResp explain(ExplainSqlReq explainSqlReq, User user) throws Exception; } diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java index 1b94f9651..298a0e897 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java @@ -6,12 +6,15 @@ import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum; import com.tencent.supersonic.common.util.cache.CacheUtils; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum; import com.tencent.supersonic.semantic.api.model.request.ModelSchemaFilterReq; +import com.tencent.supersonic.semantic.api.model.response.ExplainResp; import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum; import com.tencent.supersonic.semantic.api.query.pojo.Cache; import com.tencent.supersonic.semantic.api.query.pojo.Filter; +import com.tencent.supersonic.semantic.api.query.request.ExplainSqlReq; import com.tencent.supersonic.semantic.api.query.request.ItemUseReq; import com.tencent.supersonic.semantic.api.query.request.QueryDimValueReq; import com.tencent.supersonic.semantic.api.query.request.QueryDslReq; @@ -64,6 +67,14 @@ public class QueryServiceImpl implements QueryService { @Override public Object queryBySql(QueryDslReq querySqlCmd, User user) throws Exception { + statUtils.initStatInfo(querySqlCmd, user); + QueryStatement queryStatement = convertToQueryStatement(querySqlCmd, user); + QueryResultWithSchemaResp results = semanticQueryEngine.execute(queryStatement); + statUtils.statInfo2DbAsync(TaskStatusEnum.SUCCESS); + return results; + } + + private QueryStatement convertToQueryStatement(QueryDslReq querySqlCmd, User user) throws Exception { ModelSchemaFilterReq filter = new ModelSchemaFilterReq(); List modelIds = new ArrayList<>(); modelIds.add(querySqlCmd.getModelId()); @@ -74,7 +85,7 @@ public class QueryServiceImpl implements QueryService { QueryStatement queryStatement = queryReqConverter.convert(querySqlCmd, domainSchemas); queryStatement.setModelId(querySqlCmd.getModelId()); - return semanticQueryEngine.execute(queryStatement); + return queryStatement; } @Override @@ -183,6 +194,32 @@ public class QueryServiceImpl implements QueryService { return statInfos; } + @Override + public ExplainResp explain(ExplainSqlReq explainSqlReq, User user) throws Exception { + QueryTypeEnum queryTypeEnum = explainSqlReq.getQueryTypeEnum(); + T queryReq = explainSqlReq.getQueryReq(); + + if (QueryTypeEnum.SQL.equals(queryTypeEnum) && queryReq instanceof QueryDslReq) { + QueryStatement queryStatement = convertToQueryStatement((QueryDslReq) queryReq, user); + return getExplainResp(queryStatement); + } + if (QueryTypeEnum.STRUCT.equals(queryTypeEnum) && queryReq instanceof QueryStructReq) { + QueryStatement queryStatement = semanticQueryEngine.plan((QueryStructReq) queryReq); + return getExplainResp(queryStatement); + } + + throw new IllegalArgumentException("Parameters are invalid, explainSqlReq: " + explainSqlReq); + } + + private ExplainResp getExplainResp(QueryStatement queryStatement) { + String sql = ""; + if (Objects.nonNull(queryStatement)) { + sql = queryStatement.getSql(); + } + return ExplainResp.builder().sql(sql).build(); + } + + private boolean isCache(QueryStructReq queryStructCmd) { if (!cacheEnable) { return false; diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/SchemaServiceImpl.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/SchemaServiceImpl.java index 68058404b..109aa779c 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/SchemaServiceImpl.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/SchemaServiceImpl.java @@ -55,7 +55,10 @@ public class SchemaServiceImpl implements SchemaService { @Override public List fetchModelSchema(ModelSchemaFilterReq filter, User user) { List domainSchemaDescList = modelService.fetchModelSchema(filter); - List statInfos = queryService.getStatInfo(new ItemUseReq()); + ItemUseReq itemUseCommend = new ItemUseReq(); + itemUseCommend.setModelIds(filter.getModelIds()); + + List statInfos = queryService.getStatInfo(itemUseCommend); log.debug("statInfos:{}", statInfos); fillCnt(domainSchemaDescList, statInfos); return domainSchemaDescList; @@ -116,7 +119,7 @@ public class SchemaServiceImpl implements SchemaService { @Override public PageInfo queryMetric(PageMetricReq pageMetricCmd, User user) { - return metricService.queryMetric(pageMetricCmd); + return metricService.queryMetric(pageMetricCmd, user); } @Override @@ -126,7 +129,7 @@ public class SchemaServiceImpl implements SchemaService { @Override public List getModelList(User user, AuthType authTypeEnum, Long domainId) { - return modelService.getModelListWithAuth(user.getName(), domainId, authTypeEnum); + return modelService.getModelListWithAuth(user, domainId, authTypeEnum); } } diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DataPermissionAOP.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DataPermissionAOP.java index 36afdef4e..dd0637795 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DataPermissionAOP.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DataPermissionAOP.java @@ -140,7 +140,7 @@ public class DataPermissionAOP { private boolean doModelAdmin(User user, QueryStructReq queryStructReq) { Long modelId = queryStructReq.getModelId(); - List modelListAdmin = modelService.getModelListWithAuth(user.getName(), null, AuthType.ADMIN); + List modelListAdmin = modelService.getModelListWithAuth(user, null, AuthType.ADMIN); if (CollectionUtils.isEmpty(modelListAdmin)) { return false; } else { @@ -153,7 +153,7 @@ public class DataPermissionAOP { private void doModelVisible(User user, QueryStructReq queryStructReq) { Boolean visible = true; Long modelId = queryStructReq.getModelId(); - List modelListVisible = modelService.getModelListWithAuth(user.getName(), null, AuthType.VISIBLE); + List modelListVisible = modelService.getModelListWithAuth(user, null, AuthType.VISIBLE); if (CollectionUtils.isEmpty(modelListVisible)) { visible = false; } else { diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/StatUtils.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/StatUtils.java index 0a6c8a5ca..a009750a3 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/StatUtils.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/StatUtils.java @@ -4,22 +4,30 @@ import com.alibaba.ttl.TransmittableThreadLocal; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.semantic.api.model.enums.QueryTypeBackEnum; import com.tencent.supersonic.semantic.api.model.enums.QueryTypeEnum; import com.tencent.supersonic.semantic.api.model.pojo.QueryStat; +import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem; +import com.tencent.supersonic.semantic.api.model.response.ModelSchemaResp; import com.tencent.supersonic.semantic.api.query.request.ItemUseReq; +import com.tencent.supersonic.semantic.api.query.request.QueryDslReq; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import com.tencent.supersonic.semantic.api.query.response.ItemUseResp; -import com.tencent.supersonic.common.pojo.enums.TaskStatusEnum; +import com.tencent.supersonic.semantic.model.domain.ModelService; import com.tencent.supersonic.semantic.query.persistence.repository.StatRepository; import java.util.ArrayList; import java.util.List; import java.util.Objects; +import java.util.Set; import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; import org.apache.commons.codec.digest.DigestUtils; import org.apache.logging.log4j.util.Strings; import org.springframework.stereotype.Component; +import org.springframework.util.CollectionUtils; @Component @Slf4j @@ -28,13 +36,17 @@ public class StatUtils { private static final TransmittableThreadLocal STATS = new TransmittableThreadLocal<>(); private final StatRepository statRepository; private final SqlFilterUtils sqlFilterUtils; + + private final ModelService modelService; private final ObjectMapper objectMapper = new ObjectMapper(); public StatUtils(StatRepository statRepository, - SqlFilterUtils sqlFilterUtils) { + SqlFilterUtils sqlFilterUtils, + ModelService modelService) { this.statRepository = statRepository; this.sqlFilterUtils = sqlFilterUtils; + this.modelService = modelService; } public static QueryStat get() { @@ -69,6 +81,44 @@ public class StatUtils { return true; } + + public void initStatInfo(QueryDslReq queryDslReq, User facadeUser) { + QueryStat queryStatInfo = new QueryStat(); + List allFields = SqlParserSelectHelper.getAllFields(queryDslReq.getSql()); + queryStatInfo.setModelId(queryDslReq.getModelId()); + ModelSchemaResp modelSchemaResp = modelService.fetchSingleModelSchema(queryDslReq.getModelId()); + + List dimensions = new ArrayList<>(); + if (Objects.nonNull(modelSchemaResp)) { + dimensions = getFieldNames(allFields, modelSchemaResp.getDimensions()); + } + + List metrics = new ArrayList<>(); + if (Objects.nonNull(modelSchemaResp)) { + metrics = getFieldNames(allFields, modelSchemaResp.getMetrics()); + } + + String userName = getUserName(facadeUser); + try { + queryStatInfo.setTraceId("") + .setModelId(queryDslReq.getModelId()) + .setUser(userName) + .setQueryType(QueryTypeEnum.SQL.getValue()) + .setQueryTypeBack(QueryTypeBackEnum.NORMAL.getState()) + .setQuerySqlCmd(queryDslReq.toString()) + .setQuerySqlCmdMd5(DigestUtils.md5Hex(queryDslReq.toString())) + .setStartTime(System.currentTimeMillis()) + .setUseResultCache(true) + .setUseSqlCache(true) + .setMetrics(objectMapper.writeValueAsString(metrics)) + .setDimensions(objectMapper.writeValueAsString(dimensions)); + } catch (JsonProcessingException e) { + log.error("initStatInfo:{}", e); + } + StatUtils.set(queryStatInfo); + + } + public void initStatInfo(QueryStructReq queryStructCmd, User facadeUser) { QueryStat queryStatInfo = new QueryStat(); String traceId = ""; @@ -76,12 +126,11 @@ public class StatUtils { List metrics = new ArrayList<>(); queryStructCmd.getAggregators().stream().forEach(aggregator -> metrics.add(aggregator.getColumn())); - String user = (Objects.nonNull(facadeUser) && Strings.isNotEmpty(facadeUser.getName())) ? facadeUser.getName() - : "Admin"; + String user = getUserName(facadeUser); try { queryStatInfo.setTraceId(traceId) - .setClassId(queryStructCmd.getModelId()) + .setModelId(queryStructCmd.getModelId()) .setUser(user) .setQueryType(QueryTypeEnum.STRUCT.getValue()) .setQueryTypeBack(QueryTypeBackEnum.NORMAL.getState()) @@ -105,6 +154,25 @@ public class StatUtils { } + private List getFieldNames(List allFields, List schemaItems) { + Set fieldNames = schemaItems + .stream() + .map(dimSchemaResp -> dimSchemaResp.getBizName()) + .collect(Collectors.toSet()); + if (!CollectionUtils.isEmpty(fieldNames)) { + return allFields.stream().filter(fieldName -> fieldNames.contains(fieldName)) + .collect(Collectors.toList()); + } + return new ArrayList<>(); + } + + private String getUserName(User facadeUser) { + return (Objects.nonNull(facadeUser) && Strings.isNotEmpty(facadeUser.getName())) ? facadeUser.getName() + : "Admin"; + } + + + public List getStatInfo(ItemUseReq itemUseCommend) { return statRepository.getStatInfo(itemUseCommend); } diff --git a/semantic/query/src/main/resources/mapper/StatMapper.xml b/semantic/query/src/main/resources/mapper/StatMapper.xml index a67f1729d..dec96b377 100644 --- a/semantic/query/src/main/resources/mapper/StatMapper.xml +++ b/semantic/query/src/main/resources/mapper/StatMapper.xml @@ -64,6 +64,12 @@ and model_id = #{modelId} + + and model_id in + + #{id} + + and metrics like concat('%',#{metric},'%') diff --git a/webapp/.gitignore b/webapp/.gitignore index 0102f4656..679cdeb79 100644 --- a/webapp/.gitignore +++ b/webapp/.gitignore @@ -19,7 +19,6 @@ supersonic-webapp.tar.gz package-lock.json yarn.lock -pnpm-lock.yaml # misc .DS_Store diff --git a/webapp/packages/chat-sdk/.gitignore b/webapp/packages/chat-sdk/.gitignore index 149d63661..bea504c15 100644 --- a/webapp/packages/chat-sdk/.gitignore +++ b/webapp/packages/chat-sdk/.gitignore @@ -13,8 +13,6 @@ /dist -pnpm-lock.yaml - # misc .DS_Store .env.local diff --git a/webapp/packages/supersonic-fe/.gitignore b/webapp/packages/supersonic-fe/.gitignore index a121849a0..8e11232af 100644 --- a/webapp/packages/supersonic-fe/.gitignore +++ b/webapp/packages/supersonic-fe/.gitignore @@ -20,7 +20,6 @@ yarn-error.log .idea yarn.lock package-lock.json -pnpm-lock.yaml *bak .vscode diff --git a/webapp/packages/supersonic-fe/package.json b/webapp/packages/supersonic-fe/package.json index f63f686bb..3054f6432 100644 --- a/webapp/packages/supersonic-fe/package.json +++ b/webapp/packages/supersonic-fe/package.json @@ -76,7 +76,7 @@ "antd": "^4.24.8", "classnames": "^2.2.6", "copy-to-clipboard": "^3.3.1", - "cross-env": "^7.0.0", + "cross-env": "^7.0.3", "crypto-js": "^4.0.0", "echarts": "^5.0.2", "echarts-for-react": "^3.0.1", diff --git a/webapp/packages/supersonic-fe/src/pages/SemanticModel/Metric/components/MetricFilter.tsx b/webapp/packages/supersonic-fe/src/pages/SemanticModel/Metric/components/MetricFilter.tsx index fde5dc890..c83e22c1d 100644 --- a/webapp/packages/supersonic-fe/src/pages/SemanticModel/Metric/components/MetricFilter.tsx +++ b/webapp/packages/supersonic-fe/src/pages/SemanticModel/Metric/components/MetricFilter.tsx @@ -81,7 +81,7 @@ const MetricFilter: React.FC = ({ filterValues = {}, onFiltersChange }) =
} onSearch={(value) => { onSearch(value); diff --git a/webapp/packages/supersonic-fe/src/pages/SemanticModel/Metric/index.tsx b/webapp/packages/supersonic-fe/src/pages/SemanticModel/Metric/index.tsx index ca089637d..609dbd756 100644 --- a/webapp/packages/supersonic-fe/src/pages/SemanticModel/Metric/index.tsx +++ b/webapp/packages/supersonic-fe/src/pages/SemanticModel/Metric/index.tsx @@ -1,9 +1,9 @@ import type { ActionType, ProColumns } from '@ant-design/pro-table'; import ProTable from '@ant-design/pro-table'; -import { message, Space, Popconfirm } from 'antd'; +import { message, Space, Popconfirm, Tag } from 'antd'; import React, { useRef, useState, useEffect } from 'react'; import type { Dispatch } from 'umi'; -import { connect } from 'umi'; +import { connect, history } from 'umi'; import type { StateType } from '../model'; import { SENSITIVE_LEVEL_ENUM } from '../constant'; import { queryMetric, deleteMetric } from '../service'; @@ -89,6 +89,20 @@ const ClassMetricTable: React.FC = ({ domainManger, dispatch }) => { { dataIndex: 'name', title: '指标名称', + render: (_, record: any) => { + if (record.hasAdminRes) { + return ( + { + history.replace(`/model/${record.domainId}/${record.modelId}/metric`); + }} + > + {record.name} + + ); + } + return <> {record.name}; + }, }, // { // dataIndex: 'alias', @@ -113,6 +127,25 @@ const ClassMetricTable: React.FC = ({ domainManger, dispatch }) => { title: '创建人', search: false, }, + { + dataIndex: 'tags', + title: '标签', + search: false, + render: (tags) => { + if (Array.isArray(tags)) { + return ( + + {tags.map((tag) => ( + + {tag} + + ))} + + ); + } + return <>--; + }, + }, { dataIndex: 'description', title: '描述', @@ -140,43 +173,47 @@ const ClassMetricTable: React.FC = ({ domainManger, dispatch }) => { dataIndex: 'x', valueType: 'option', render: (_, record) => { - return ( - - { - setMetricItem(record); - setCreateModalVisible(true); - }} - > - 编辑 - - - { - const { code, msg } = await deleteMetric(record.id); - if (code === 200) { - setMetricItem(undefined); - actionRef.current?.reload(); - } else { - message.error(msg); - } - }} - > + if (record.hasAdminRes) { + return ( + { setMetricItem(record); + setCreateModalVisible(true); }} > - 删除 + 编辑 - - - ); + + { + const { code, msg } = await deleteMetric(record.id); + if (code === 200) { + setMetricItem(undefined); + queryMetricList(); + } else { + message.error(msg); + } + }} + > + { + setMetricItem(record); + }} + > + 删除 + + + + ); + } else { + return <>; + } }, }, ]; @@ -239,7 +276,7 @@ const ClassMetricTable: React.FC = ({ domainManger, dispatch }) => { metricItem={metricItem} onSubmit={() => { setCreateModalVisible(false); - actionRef?.current?.reload(); + queryMetricList(); dispatch({ type: 'domainManger/queryMetricList', payload: { diff --git a/webapp/packages/supersonic-fe/src/pages/SemanticModel/Metric/style.less b/webapp/packages/supersonic-fe/src/pages/SemanticModel/Metric/style.less index df680eae5..df11043b0 100644 --- a/webapp/packages/supersonic-fe/src/pages/SemanticModel/Metric/style.less +++ b/webapp/packages/supersonic-fe/src/pages/SemanticModel/Metric/style.less @@ -15,7 +15,7 @@ // margin-bottom: 12px; background: #fff; border-radius: 10px; - width: 500px; + width: 540px; margin: 0 auto; .searchInput { width: 100%; diff --git a/webapp/packages/supersonic-fe/src/pages/SemanticModel/components/ClassMetricTable.tsx b/webapp/packages/supersonic-fe/src/pages/SemanticModel/components/ClassMetricTable.tsx index 1826d83b3..65a447aca 100644 --- a/webapp/packages/supersonic-fe/src/pages/SemanticModel/components/ClassMetricTable.tsx +++ b/webapp/packages/supersonic-fe/src/pages/SemanticModel/components/ClassMetricTable.tsx @@ -1,6 +1,6 @@ import type { ActionType, ProColumns } from '@ant-design/pro-table'; import ProTable from '@ant-design/pro-table'; -import { message, Button, Space, Popconfirm, Input } from 'antd'; +import { message, Button, Space, Popconfirm, Input, Tag } from 'antd'; import React, { useRef, useState } from 'react'; import type { Dispatch } from 'umi'; import { connect } from 'umi'; @@ -76,7 +76,7 @@ const ClassMetricTable: React.FC = ({ domainManger, dispatch }) => { dataIndex: 'key', title: '指标搜索', hideInTable: true, - renderFormItem: () => , + renderFormItem: () => , }, { dataIndex: 'alias', @@ -101,6 +101,25 @@ const ClassMetricTable: React.FC = ({ domainManger, dispatch }) => { title: '创建人', search: false, }, + { + dataIndex: 'tags', + title: '标签', + search: false, + render: (tags) => { + if (Array.isArray(tags)) { + return ( + + {tags.map((tag) => ( + + {tag} + + ))} + + ); + } + return <>--; + }, + }, { dataIndex: 'description', title: '描述', diff --git a/webapp/packages/supersonic-fe/src/pages/SemanticModel/components/MetricInfoCreateForm.tsx b/webapp/packages/supersonic-fe/src/pages/SemanticModel/components/MetricInfoCreateForm.tsx index 130780a2b..95777067b 100644 --- a/webapp/packages/supersonic-fe/src/pages/SemanticModel/components/MetricInfoCreateForm.tsx +++ b/webapp/packages/supersonic-fe/src/pages/SemanticModel/components/MetricInfoCreateForm.tsx @@ -24,7 +24,7 @@ import FormItemTitle from '@/components/FormHelper/FormItemTitle'; import styles from './style.less'; import { getMeasureListByModelId } from '../service'; import TableTitleTooltips from '../components/TableTitleTooltips'; -import { creatExprMetric, updateExprMetric, mockMetricAlias } from '../service'; +import { creatExprMetric, updateExprMetric, mockMetricAlias, getMetricTags } from '../service'; import { ISemantic } from '../data'; import { history } from 'umi'; @@ -75,6 +75,8 @@ const MetricInfoCreateForm: React.FC = ({ const [hasMeasuresState, setHasMeasuresState] = useState(true); const [llmLoading, setLlmLoading] = useState(false); + const [tagOptions, setTagOptions] = useState<{ label: string; value: string }[]>([]); + const forward = () => setCurrentStep(currentStep + 1); const backward = () => setCurrentStep(currentStep - 1); @@ -95,6 +97,7 @@ const MetricInfoCreateForm: React.FC = ({ useEffect(() => { queryClassMeasureList(); + queryMetricTags(); }, []); const handleNext = async () => { @@ -126,6 +129,7 @@ const MetricInfoCreateForm: React.FC = ({ dataFormat, dataFormatType, alias, + tags, } = metricItem as any; const isPercent = dataFormatType === 'percent'; const isDecimal = dataFormatType === 'decimal'; @@ -135,6 +139,7 @@ const MetricInfoCreateForm: React.FC = ({ bizName, sensitiveLevel, description, + tags, // isPercent, dataFormatType: dataFormatType || '', alias: alias && alias.trim() ? alias.split(',') : [], @@ -204,6 +209,22 @@ const MetricInfoCreateForm: React.FC = ({ } }; + const queryMetricTags = async () => { + const { code, data } = await getMetricTags(); + if (code === 200) { + // form.setFieldValue('alias', Array.from(new Set([...formAlias, ...data]))); + setTagOptions( + Array.isArray(data) + ? data.map((tag: string) => { + return { label: tag, value: tag }; + }) + : [], + ); + } else { + message.error('获取指标标签失败'); + } + }; + const renderContent = () => { if (currentStep === 1) { return ( @@ -277,6 +298,15 @@ const MetricInfoCreateForm: React.FC = ({ )} + +