diff --git a/.gitignore b/.gitignore index 6ad72f6b8..2f509cfc0 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,5 @@ assembly/runtime/* /runtime **/.flattened-pom.xml chm_db/ -__pycache__/ \ No newline at end of file +__pycache__/ +/dict \ No newline at end of file diff --git a/assembly/bin/supersonic-build.sh b/assembly/bin/supersonic-build.sh index 3f0fc75e9..335147ffa 100755 --- a/assembly/bin/supersonic-build.sh +++ b/assembly/bin/supersonic-build.sh @@ -67,4 +67,4 @@ moveToRuntime standalone setEnvToWeb chat setEnvToWeb semantic -rm -fr ${buildDir}/webapp \ No newline at end of file +rm -fr ${buildDir}/webapp diff --git a/assembly/bin/supersonic-daemon.sh b/assembly/bin/supersonic-daemon.sh index 49dfa53bd..ae4e3eb37 100755 --- a/assembly/bin/supersonic-daemon.sh +++ b/assembly/bin/supersonic-daemon.sh @@ -192,4 +192,4 @@ case "$command" in *) echo "Use command {start|stop|restart} to run." exit 1 -esac \ No newline at end of file +esac diff --git a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authorization/service/AuthService.java b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authorization/service/AuthService.java index ab361408d..5f9ed3dd2 100644 --- a/auth/api/src/main/java/com/tencent/supersonic/auth/api/authorization/service/AuthService.java +++ b/auth/api/src/main/java/com/tencent/supersonic/auth/api/authorization/service/AuthService.java @@ -10,7 +10,7 @@ public interface AuthService { List queryAuthGroups(String domainId, Integer groupId); - void updateAuthGroup(AuthGroup group); + void addOrUpdateAuthGroup(AuthGroup group); void removeAuthGroup(AuthGroup group); diff --git a/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/application/AuthServiceImpl.java b/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/application/AuthServiceImpl.java index 9053a8229..25f7ba3b5 100644 --- 
a/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/application/AuthServiceImpl.java +++ b/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/application/AuthServiceImpl.java @@ -53,7 +53,7 @@ public class AuthServiceImpl implements AuthService { } @Override - public void updateAuthGroup(AuthGroup group) { + public void addOrUpdateAuthGroup(AuthGroup group) { Gson g = new Gson(); if (group.getGroupId() == null) { int nextGroupId = 1; diff --git a/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/rest/AuthController.java b/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/rest/AuthController.java index 7629ca64c..cea9cbef8 100644 --- a/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/rest/AuthController.java +++ b/auth/authorization/src/main/java/com/tencent/supersonic/auth/authorization/rest/AuthController.java @@ -40,7 +40,7 @@ public class AuthController { @PostMapping("/createGroup") public void newAuthGroup(@RequestBody AuthGroup group) { group.setGroupId(null); - authService.updateAuthGroup(group); + authService.addOrUpdateAuthGroup(group); } @PostMapping("/removeGroup") @@ -58,7 +58,7 @@ public class AuthController { if (group.getGroupId() == null || group.getGroupId() == 0) { throw new RuntimeException("groupId is empty"); } - authService.updateAuthGroup(group); + authService.addOrUpdateAuthGroup(group); } /** diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticSchema.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticSchema.java index 4cf01216f..ebb29035c 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticSchema.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/SemanticSchema.java @@ -7,6 +7,7 @@ import java.util.Map; import java.util.stream.Collectors; public class SemanticSchema implements Serializable { + private List 
modelSchemaList; public SemanticSchema(List modelSchemaList) { @@ -34,12 +35,28 @@ public class SemanticSchema implements Serializable { return dimensions; } + public List getDimensions(Long modelId) { + List dimensions = getDimensions(); + return getElementsByModelId(modelId, dimensions); + } + public List getMetrics() { List metrics = new ArrayList<>(); modelSchemaList.stream().forEach(d -> metrics.addAll(d.getMetrics())); return metrics; } + public List getMetrics(Long modelId) { + List metrics = getMetrics(); + return getElementsByModelId(modelId, metrics); + } + + private List getElementsByModelId(Long modelId, List elements) { + return elements.stream() + .filter(schemaElement -> modelId.equals(schemaElement.getModel())) + .collect(Collectors.toList()); + } + public List getModels() { List models = new ArrayList<>(); modelSchemaList.stream().forEach(d -> models.add(d.getModel())); diff --git a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/QueryDataReq.java b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/QueryDataReq.java index b93804ff0..af9a1425e 100644 --- a/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/QueryDataReq.java +++ b/chat/api/src/main/java/com/tencent/supersonic/chat/api/pojo/request/QueryDataReq.java @@ -1,25 +1,20 @@ package com.tencent.supersonic.chat.api.pojo.request; +import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.common.pojo.DateConf; -import com.tencent.supersonic.common.pojo.Order; -import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; import java.util.HashSet; import java.util.Set; import lombok.Data; @Data public class QueryDataReq { - String queryMode; - SchemaElement model; - Set metrics = new HashSet<>(); - Set dimensions = new HashSet<>(); - Set dimensionFilters = new HashSet<>(); - Set metricFilters = new HashSet<>(); - private AggregateTypeEnum 
aggType = AggregateTypeEnum.NONE; - private Set orders = new HashSet<>(); + private User user; + private Set metrics = new HashSet<>(); + private Set dimensions = new HashSet<>(); + private Set dimensionFilters = new HashSet<>(); private DateConf dateInfo; - private Long limit; - private Boolean nativeQuery = false; + private Long queryId = 7L; + private Integer parseId = 2; } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/DateFieldCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/DateFieldCorrector.java deleted file mode 100644 index f4221a8d8..000000000 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/DateFieldCorrector.java +++ /dev/null @@ -1,27 +0,0 @@ -package com.tencent.supersonic.chat.corrector; - -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; -import java.util.List; -import lombok.extern.slf4j.Slf4j; -import org.springframework.util.CollectionUtils; - -@Slf4j -public class DateFieldCorrector extends BaseSemanticCorrector { - - @Override - public void correct(SemanticCorrectInfo semanticCorrectInfo) { - - String sql = semanticCorrectInfo.getSql(); - List whereFields = SqlParserSelectHelper.getWhereFields(sql); - if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(DATE_FIELD)) { - String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId()); - sql = SqlParserUpdateHelper.addWhere(sql, DATE_FIELD, currentDate); - } - semanticCorrectInfo.setPreSql(semanticCorrectInfo.getSql()); - semanticCorrectInfo.setSql(sql); - } - -} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FieldCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FieldCorrector.java 
deleted file mode 100644 index 77cb01c3d..000000000 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FieldCorrector.java +++ /dev/null @@ -1,18 +0,0 @@ -package com.tencent.supersonic.chat.corrector; - -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; -import lombok.extern.slf4j.Slf4j; - -@Slf4j -public class FieldCorrector extends BaseSemanticCorrector { - - @Override - public void correct(SemanticCorrectInfo semanticCorrectInfo) { - String preSql = semanticCorrectInfo.getSql(); - semanticCorrectInfo.setPreSql(preSql); - String sql = SqlParserUpdateHelper.replaceFields(preSql, - getFieldToBizName(semanticCorrectInfo.getParseInfo().getModelId())); - semanticCorrectInfo.setSql(sql); - } -} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FunctionAliasCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FunctionAliasCorrector.java deleted file mode 100644 index 7564942c4..000000000 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FunctionAliasCorrector.java +++ /dev/null @@ -1,16 +0,0 @@ -package com.tencent.supersonic.chat.corrector; - -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; -import lombok.extern.slf4j.Slf4j; - -@Slf4j -public class FunctionAliasCorrector extends BaseSemanticCorrector { - - @Override - public void correct(SemanticCorrectInfo semanticCorrectInfo) { - String replaceAlias = SqlParserUpdateHelper.replaceAlias(semanticCorrectInfo.getSql()); - semanticCorrectInfo.setSql(replaceAlias); - } - -} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FunctionCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FunctionCorrector.java deleted file mode 100644 index e0a3a3210..000000000 --- 
a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FunctionCorrector.java +++ /dev/null @@ -1,17 +0,0 @@ -package com.tencent.supersonic.chat.corrector; - -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; -import lombok.extern.slf4j.Slf4j; - -@Slf4j -public class FunctionCorrector extends BaseSemanticCorrector { - - @Override - public void correct(SemanticCorrectInfo semanticCorrectInfo) { - String preSql = semanticCorrectInfo.getSql(); - semanticCorrectInfo.setPreSql(preSql); - String sql = SqlParserUpdateHelper.replaceFunction(preSql); - semanticCorrectInfo.setSql(sql); - } -} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FieldNameCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalCorrector.java similarity index 64% rename from chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FieldNameCorrector.java rename to chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalCorrector.java index f94b98253..774e12aeb 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FieldNameCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GlobalCorrector.java @@ -16,11 +16,39 @@ import lombok.extern.slf4j.Slf4j; import org.springframework.util.CollectionUtils; @Slf4j -public class FieldNameCorrector extends BaseSemanticCorrector { +public class GlobalCorrector extends BaseSemanticCorrector { @Override public void correct(SemanticCorrectInfo semanticCorrectInfo) { + replaceAlias(semanticCorrectInfo); + + updateFieldNameByLinkingValue(semanticCorrectInfo); + + updateFieldNameByBizName(semanticCorrectInfo); + + addAggregateToMetric(semanticCorrectInfo); + } + + private void addAggregateToMetric(SemanticCorrectInfo semanticCorrectInfo) { + + } + + private void replaceAlias(SemanticCorrectInfo semanticCorrectInfo) { + String replaceAlias = 
SqlParserUpdateHelper.replaceAlias(semanticCorrectInfo.getSql()); + semanticCorrectInfo.setSql(replaceAlias); + } + + private void updateFieldNameByBizName(SemanticCorrectInfo semanticCorrectInfo) { + + Map fieldToBizName = getFieldToBizName(semanticCorrectInfo.getParseInfo().getModelId()); + + String sql = SqlParserUpdateHelper.replaceFields(semanticCorrectInfo.getSql(), fieldToBizName); + + semanticCorrectInfo.setSql(sql); + } + + private void updateFieldNameByLinkingValue(SemanticCorrectInfo semanticCorrectInfo) { Object context = semanticCorrectInfo.getParseInfo().getProperties().get(Constants.CONTEXT); if (Objects.isNull(context)) { return; @@ -45,5 +73,4 @@ public class FieldNameCorrector extends BaseSemanticCorrector { String sql = SqlParserUpdateHelper.replaceFieldNameByValue(preSql, fieldValueToFieldNames); semanticCorrectInfo.setSql(sql); } - -} +} \ No newline at end of file diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java new file mode 100644 index 000000000..c931d2f0f --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/GroupByCorrector.java @@ -0,0 +1,15 @@ +package com.tencent.supersonic.chat.corrector; + +import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class GroupByCorrector extends BaseSemanticCorrector { + + @Override + public void correct(SemanticCorrectInfo semanticCorrectInfo) { + + + + } +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java new file mode 100644 index 000000000..c5d8a514d --- /dev/null +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/HavingCorrector.java @@ -0,0 +1,14 @@ +package com.tencent.supersonic.chat.corrector; + +import 
com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +public class HavingCorrector extends BaseSemanticCorrector { + + @Override + public void correct(SemanticCorrectInfo semanticCorrectInfo) { + + } + +} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/QueryFilterAppend.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/QueryFilterAppend.java deleted file mode 100644 index 4bb63515d..000000000 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/QueryFilterAppend.java +++ /dev/null @@ -1,48 +0,0 @@ -package com.tencent.supersonic.chat.corrector; - -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; -import com.tencent.supersonic.common.pojo.Constants; -import com.tencent.supersonic.common.util.StringUtil; -import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; -import java.util.Objects; -import java.util.stream.Collectors; -import lombok.extern.slf4j.Slf4j; -import net.sf.jsqlparser.JSQLParserException; -import net.sf.jsqlparser.expression.Expression; -import net.sf.jsqlparser.parser.CCJSqlParserUtil; -import org.apache.commons.collections.CollectionUtils; -import org.apache.commons.lang3.StringUtils; - -@Slf4j -public class QueryFilterAppend extends BaseSemanticCorrector { - - @Override - public void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException { - String queryFilter = getQueryFilter(semanticCorrectInfo.getQueryFilters()); - String preSql = semanticCorrectInfo.getSql(); - - if (StringUtils.isNotEmpty(queryFilter)) { - log.info("add queryFilter to preSql :{}", queryFilter); - Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter); - String sql = SqlParserUpdateHelper.addWhere(preSql, expression); - semanticCorrectInfo.setPreSql(preSql); - semanticCorrectInfo.setSql(sql); - } - } - - private String 
getQueryFilter(QueryFilters queryFilters) { - if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) { - return null; - } - return queryFilters.getFilters().stream() - .map(filter -> { - String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName()); - String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue()); - String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString()); - return bizNameWrap + operatorWrap + valueWrap; - }) - .collect(Collectors.joining(Constants.AND_UPPER)); - } - -} diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectFieldAppendCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectCorrector.java similarity index 96% rename from chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectFieldAppendCorrector.java rename to chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectCorrector.java index 5476370fb..62407df25 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectFieldAppendCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/SelectCorrector.java @@ -13,11 +13,12 @@ import net.sf.jsqlparser.expression.Expression; import org.springframework.util.CollectionUtils; @Slf4j -public class SelectFieldAppendCorrector extends BaseSemanticCorrector { +public class SelectCorrector extends BaseSemanticCorrector { @Override public void correct(SemanticCorrectInfo semanticCorrectInfo) { String preSql = semanticCorrectInfo.getSql(); + if (SqlParserSelectHelper.hasAggregateFunction(preSql)) { Expression havingExpression = SqlParserSelectHelper.getHavingExpression(preSql); if (Objects.nonNull(havingExpression)) { diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/TableNameCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/TableCorrector.java similarity index 91% rename from 
chat/core/src/main/java/com/tencent/supersonic/chat/corrector/TableNameCorrector.java rename to chat/core/src/main/java/com/tencent/supersonic/chat/corrector/TableCorrector.java index 03f9b7ecb..1a64727c3 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/TableNameCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/TableCorrector.java @@ -5,7 +5,7 @@ import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import lombok.extern.slf4j.Slf4j; @Slf4j -public class TableNameCorrector extends BaseSemanticCorrector { +public class TableCorrector extends BaseSemanticCorrector { public static final String TABLE_PREFIX = "t_"; diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FieldValueCorrector.java b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java similarity index 50% rename from chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FieldValueCorrector.java rename to chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java index b660f8946..fc607194a 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/FieldValueCorrector.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/corrector/WhereCorrector.java @@ -1,26 +1,92 @@ package com.tencent.supersonic.chat.corrector; -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; import com.tencent.supersonic.chat.api.pojo.SchemaElement; import com.tencent.supersonic.chat.api.pojo.SchemaValueMap; +import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; import com.tencent.supersonic.chat.api.pojo.SemanticSchema; +import com.tencent.supersonic.chat.api.pojo.request.QueryFilters; +import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper; +import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.util.ContextUtils; +import com.tencent.supersonic.common.util.StringUtil; +import 
com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import com.tencent.supersonic.knowledge.service.SchemaService; +import com.tencent.supersonic.semantic.api.model.enums.TimeDimensionEnum; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; +import net.sf.jsqlparser.JSQLParserException; +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import org.apache.commons.lang3.StringUtils; import org.apache.logging.log4j.util.Strings; import org.springframework.util.CollectionUtils; @Slf4j -public class FieldValueCorrector extends BaseSemanticCorrector { +public class WhereCorrector extends BaseSemanticCorrector { @Override - public void correct(SemanticCorrectInfo semanticCorrectInfo) { + public void correct(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException { + + addDateIfNotExist(semanticCorrectInfo); + + parserDateDiffFunction(semanticCorrectInfo); + + addQueryFilter(semanticCorrectInfo); + + updateFieldValueByTechName(semanticCorrectInfo); + } + + private void addQueryFilter(SemanticCorrectInfo semanticCorrectInfo) throws JSQLParserException { + String queryFilter = getQueryFilter(semanticCorrectInfo.getQueryFilters()); + + String preSql = semanticCorrectInfo.getSql(); + + if (StringUtils.isNotEmpty(queryFilter)) { + log.info("add queryFilter to preSql :{}", queryFilter); + Expression expression = CCJSqlParserUtil.parseCondExpression(queryFilter); + String sql = SqlParserUpdateHelper.addWhere(preSql, expression); + semanticCorrectInfo.setPreSql(preSql); + semanticCorrectInfo.setSql(sql); + } + } + + private void parserDateDiffFunction(SemanticCorrectInfo semanticCorrectInfo) { + String preSql = semanticCorrectInfo.getSql(); + semanticCorrectInfo.setPreSql(preSql); + String sql = 
SqlParserUpdateHelper.replaceFunction(preSql); + semanticCorrectInfo.setSql(sql); + } + + private void addDateIfNotExist(SemanticCorrectInfo semanticCorrectInfo) { + String sql = semanticCorrectInfo.getSql(); + List whereFields = SqlParserSelectHelper.getWhereFields(sql); + if (CollectionUtils.isEmpty(whereFields) || !whereFields.contains(TimeDimensionEnum.DAY.getName())) { + String currentDate = DSLDateHelper.getReferenceDate(semanticCorrectInfo.getParseInfo().getModelId()); + sql = SqlParserUpdateHelper.addWhere(sql, TimeDimensionEnum.DAY.getName(), currentDate); + } + semanticCorrectInfo.setSql(sql); + } + + private String getQueryFilter(QueryFilters queryFilters) { + if (Objects.isNull(queryFilters) || CollectionUtils.isEmpty(queryFilters.getFilters())) { + return null; + } + return queryFilters.getFilters().stream() + .map(filter -> { + String bizNameWrap = StringUtil.getSpaceWrap(filter.getBizName()); + String operatorWrap = StringUtil.getSpaceWrap(filter.getOperator().getValue()); + String valueWrap = StringUtil.getCommaWrap(filter.getValue().toString()); + return bizNameWrap + operatorWrap + valueWrap; + }) + .collect(Collectors.joining(Constants.AND_UPPER)); + } + + private void updateFieldValueByTechName(SemanticCorrectInfo semanticCorrectInfo) { SemanticSchema semanticSchema = ContextUtils.getBean(SchemaService.class).getSemanticSchema(); Long modelId = semanticCorrectInfo.getParseInfo().getModel().getId(); List dimensions = semanticSchema.getDimensions().stream() @@ -39,7 +105,6 @@ public class FieldValueCorrector extends BaseSemanticCorrector { return; } - private Map> getAliasAndBizNameToTechName(List dimensions) { if (CollectionUtils.isEmpty(dimensions)) { return new HashMap<>(); diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java index f286c1177..9b6c24422 100644 --- 
a/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/parser/llm/dsl/LLMDslParser.java @@ -408,27 +408,20 @@ public class LLMDslParser implements SemanticParser { protected List getFieldNameList(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema, LLMParserConfig llmParserConfig) { + + Set results = getTopNFieldNames(modelId, semanticSchema, llmParserConfig); + + Set fieldNameList = getMatchedFieldNames(queryCtx, modelId, semanticSchema); + + results.addAll(fieldNameList); + return new ArrayList<>(results); + } + + protected Set getMatchedFieldNames(QueryContext queryCtx, Long modelId, SemanticSchema semanticSchema) { Map itemIdToName = getItemIdToName(modelId, semanticSchema); - - Set results = semanticSchema.getDimensions().stream() - .filter(schemaElement -> modelId.equals(schemaElement.getModel())) - .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) - .limit(llmParserConfig.getDimensionTopN()) - .map(entry -> entry.getName()) - .collect(Collectors.toSet()); - - Set metrics = semanticSchema.getMetrics().stream() - .filter(schemaElement -> modelId.equals(schemaElement.getModel())) - .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) - .limit(llmParserConfig.getMetricTopN()) - .map(entry -> entry.getName()) - .collect(Collectors.toSet()); - - results.addAll(metrics); - List matchedElements = queryCtx.getMapInfo().getMatchedElements(modelId); if (CollectionUtils.isEmpty(matchedElements)) { - return new ArrayList<>(results); + return new HashSet<>(); } Set fieldNameList = matchedElements.stream() .filter(schemaElementMatch -> { @@ -447,13 +440,29 @@ public class LLMDslParser implements SemanticParser { }) .filter(name -> StringUtils.isNotEmpty(name) && !name.contains("%")) .collect(Collectors.toSet()); - results.addAll(fieldNameList); - return new ArrayList<>(results); + return fieldNameList; + } + + private Set 
getTopNFieldNames(Long modelId, SemanticSchema semanticSchema, + LLMParserConfig llmParserConfig) { + Set results = semanticSchema.getDimensions(modelId).stream() + .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) + .limit(llmParserConfig.getDimensionTopN()) + .map(entry -> entry.getName()) + .collect(Collectors.toSet()); + + Set metrics = semanticSchema.getMetrics(modelId).stream() + .sorted(Comparator.comparing(SchemaElement::getUseCnt).reversed()) + .limit(llmParserConfig.getMetricTopN()) + .map(entry -> entry.getName()) + .collect(Collectors.toSet()); + + results.addAll(metrics); + return results; } protected Map getItemIdToName(Long modelId, SemanticSchema semanticSchema) { - return semanticSchema.getDimensions().stream() - .filter(entry -> modelId.equals(entry.getModel())) + return semanticSchema.getDimensions(modelId).stream() .collect(Collectors.toMap(SchemaElement::getId, SchemaElement::getName, (value1, value2) -> value2)); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatQueryController.java b/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatQueryController.java index f88b24ea8..d27d36842 100644 --- a/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatQueryController.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/rest/ChatQueryController.java @@ -72,6 +72,7 @@ public class ChatQueryController { public Object queryData(@RequestBody QueryDataReq queryData, HttpServletRequest request, HttpServletResponse response) throws Exception { + queryData.setUser(UserHolder.findUser(request, response)); return queryService.executeDirectQuery(queryData, UserHolder.findUser(request, response)); } diff --git a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java index efcc45b77..40eb1aa51 100644 --- 
a/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java +++ b/chat/core/src/main/java/com/tencent/supersonic/chat/service/impl/QueryServiceImpl.java @@ -3,11 +3,14 @@ package com.tencent.supersonic.chat.service.impl; import com.tencent.supersonic.auth.api.authentication.pojo.User; import com.tencent.supersonic.chat.api.component.SchemaMapper; +import com.tencent.supersonic.chat.api.component.SemanticLayer; import com.tencent.supersonic.chat.api.component.SemanticQuery; import com.tencent.supersonic.chat.api.component.SemanticParser; import com.tencent.supersonic.chat.api.pojo.ChatContext; import com.tencent.supersonic.chat.api.pojo.QueryContext; import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; +import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq; +import com.tencent.supersonic.chat.api.pojo.request.QueryFilter; import com.tencent.supersonic.chat.api.pojo.request.DimensionValueReq; import com.tencent.supersonic.chat.api.pojo.request.ExecuteQueryReq; import com.tencent.supersonic.chat.api.pojo.request.QueryReq; @@ -15,13 +18,15 @@ import com.tencent.supersonic.chat.api.pojo.response.EntityInfo; import com.tencent.supersonic.chat.api.pojo.response.ParseResp; import com.tencent.supersonic.chat.api.pojo.response.QueryResult; import com.tencent.supersonic.chat.api.pojo.response.QueryState; +import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult; import com.tencent.supersonic.chat.api.pojo.response.SolvedQueryRecallResp; import com.tencent.supersonic.chat.persistence.dataobject.ChatParseDO; import com.tencent.supersonic.chat.persistence.dataobject.CostType; import com.tencent.supersonic.chat.persistence.dataobject.StatisticsDO; import com.tencent.supersonic.chat.query.QuerySelector; -import com.tencent.supersonic.chat.api.pojo.request.QueryDataReq; import com.tencent.supersonic.chat.query.QueryManager; +import com.tencent.supersonic.chat.query.llm.dsl.DslQuery; +import 
com.tencent.supersonic.chat.query.llm.dsl.LLMResp; import com.tencent.supersonic.chat.queryresponder.QueryResponder; import com.tencent.supersonic.chat.service.ChatService; import com.tencent.supersonic.chat.service.QueryService; @@ -29,25 +34,29 @@ import com.tencent.supersonic.chat.service.SemanticService; import com.tencent.supersonic.chat.service.StatisticsService; import com.tencent.supersonic.chat.utils.ComponentFactory; +import java.util.Map; import com.tencent.supersonic.semantic.api.model.response.ExplainResp; import java.util.List; import java.util.ArrayList; import java.util.Set; import java.util.HashSet; +import java.util.HashMap; import java.util.Comparator; import java.util.Objects; import java.util.stream.Collectors; +import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.common.pojo.DateConf; import com.tencent.supersonic.common.util.ContextUtils; import com.tencent.supersonic.common.util.JsonUtil; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum; import com.tencent.supersonic.semantic.api.query.pojo.Filter; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import lombok.extern.slf4j.Slf4j; import org.apache.calcite.sql.parser.SqlParseException; -import org.springframework.beans.BeanUtils; +import org.apache.commons.collections.CollectionUtils; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Primary; @@ -175,34 +184,26 @@ public class QueryServiceImpl implements QueryService { ChatContext chatCtx = chatService.getOrCreateContext(queryReq.getChatId()); chatCtx.setAgentId(queryReq.getAgentId()); Long startTime = System.currentTimeMillis(); - QueryResult queryResult = null; - try { - queryResult = 
semanticQuery.execute(queryReq.getUser()); - } catch (Exception e) { - log.error("query execute failed, queryText:{}", queryReq.getQueryText(), e); - queryResult = new QueryResult(); - queryResult.setQueryState(QueryState.INVALID); + QueryResult queryResult = semanticQuery.execute(queryReq.getUser()); + + if (queryResult != null) { + timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime)) + .interfaceName(semanticQuery.getClass().getSimpleName()).type(CostType.QUERY.getType()).build()); + saveInfo(timeCostDOList, queryReq.getQueryText(), queryReq.getQueryId(), + queryReq.getUser().getName(), queryReq.getChatId().longValue()); + queryResult.setChatContext(parseInfo); + // update chat context after a successful semantic query + if (queryReq.isSaveAnswer() && QueryState.SUCCESS.equals(queryResult.getQueryState())) { + chatCtx.setParseInfo(parseInfo); + chatService.updateContext(chatCtx); + } + chatCtx.setQueryText(queryReq.getQueryText()); + chatCtx.setUser(queryReq.getUser().getName()); + chatService.updateQuery(queryReq.getQueryId(), queryResult, chatCtx); + } else { + chatService.deleteChatQuery(queryReq.getQueryId()); } - timeCostDOList.add(StatisticsDO.builder().cost((int) (System.currentTimeMillis() - startTime)) - .interfaceName(semanticQuery.getClass().getSimpleName()).type(CostType.QUERY.getType()).build()); - saveInfo(timeCostDOList, queryReq.getQueryText(), queryReq.getQueryId(), - queryReq.getUser().getName(), queryReq.getChatId().longValue()); - queryResult.setChatContext(parseInfo); - // update chat context after a successful semantic query - if (queryReq.isSaveAnswer() && QueryState.SUCCESS.equals(queryResult.getQueryState())) { - chatCtx.setParseInfo(parseInfo); - chatService.updateContext(chatCtx); - queryResponder.saveSolvedQuery(queryReq.getQueryText(), queryReq.getQueryId(), queryReq.getParseId()); - } - chatCtx.setQueryText(queryReq.getQueryText()); - chatCtx.setUser(queryReq.getUser().getName()); - 
chatService.updateQuery(queryReq.getQueryId(), queryResult, chatCtx); - if (!QueryState.SUCCESS.equals(queryResult.getQueryState())) { - List solvedQueryRecallResps = - queryResponder.recallSolvedQuery(queryReq.getQueryText()); - queryResult.setSimilarSolvedQuery(solvedQueryRecallResps); - } return queryResult; } @@ -273,8 +274,52 @@ public class QueryServiceImpl implements QueryService { @Override public QueryResult executeDirectQuery(QueryDataReq queryData, User user) throws SqlParseException { - SemanticQuery semanticQuery = QueryManager.createRuleQuery(queryData.getQueryMode()); - BeanUtils.copyProperties(queryData, semanticQuery.getParseInfo()); + ChatParseDO chatParseDO = chatService.getParseInfo(queryData.getQueryId(), + queryData.getUser().getName(), queryData.getParseId()); + SemanticParseInfo parseInfo = JsonUtil.toObject(chatParseDO.getParseInfo(), SemanticParseInfo.class); + if (!parseInfo.getQueryMode().equals(DslQuery.QUERY_MODE)) { + if (CollectionUtils.isNotEmpty(queryData.getDimensions())) { + parseInfo.setDimensions(queryData.getDimensions()); + } + if (CollectionUtils.isNotEmpty(queryData.getMetrics())) { + parseInfo.setMetrics(queryData.getMetrics()); + } + if (CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) { + parseInfo.setDimensionFilters(queryData.getDimensionFilters()); + } + } + if (Objects.nonNull(queryData.getDateInfo())) { + parseInfo.setDateInfo(queryData.getDateInfo()); + } + if (parseInfo.getQueryMode().equals(DslQuery.QUERY_MODE) + && CollectionUtils.isNotEmpty(queryData.getDimensionFilters())) { + Map> filedNameToValueMap = new HashMap<>(); + String json = JsonUtil.toString(parseInfo.getProperties().get(Constants.CONTEXT)); + DSLParseResult dslParseResult = JsonUtil.toObject(json, DSLParseResult.class); + LLMResp llmResp = dslParseResult.getLlmResp(); + String correctorSql = llmResp.getCorrectorSql(); + log.info("correctorSql before replacing:{}", correctorSql); + for (QueryFilter dslQueryFilter : 
queryData.getDimensionFilters()) { + for (QueryFilter queryFilter : parseInfo.getDimensionFilters()) { + if (dslQueryFilter.getBizName().equals(queryFilter.getBizName())) { + Map map = new HashMap<>(); + map.put(queryFilter.getValue().toString(), dslQueryFilter.getValue().toString()); + filedNameToValueMap.put(dslQueryFilter.getBizName(), map); + break; + } + } + } + log.info("filedNameToValueMap:{}", filedNameToValueMap); + correctorSql = SqlParserUpdateHelper.replaceValue(correctorSql, filedNameToValueMap); + log.info("correctorSql after replacing:{}", correctorSql); + llmResp.setCorrectorSql(correctorSql); + dslParseResult.setLlmResp(llmResp); + Map properties = new HashMap<>(); + properties.put(Constants.CONTEXT, dslParseResult); + parseInfo.setProperties(properties); + } + SemanticQuery semanticQuery = QueryManager.createQuery(parseInfo.getQueryMode()); + semanticQuery.setParseInfo(parseInfo); QueryResult queryResult = semanticQuery.execute(user); queryResult.setChatContext(semanticQuery.getParseInfo()); return queryResult; @@ -282,8 +327,6 @@ public class QueryServiceImpl implements QueryService { @Override public Object queryDimensionValue(DimensionValueReq dimensionValueReq, User user) throws Exception { - com.tencent.supersonic.semantic.query.service.QueryService queryService = - ContextUtils.getBean(com.tencent.supersonic.semantic.query.service.QueryService.class); QueryStructReq queryStructReq = new QueryStructReq(); DateConf dateConf = new DateConf(); @@ -307,7 +350,8 @@ public class QueryServiceImpl implements QueryService { dimensionFilters.add(dimensionFilter); queryStructReq.setDimensionFilters(dimensionFilters); } - QueryResultWithSchemaResp queryResultWithSchemaResp = queryService.queryByStructWithAuth(queryStructReq, user); + SemanticLayer semanticLayer = ComponentFactory.getSemanticLayer(); + QueryResultWithSchemaResp queryResultWithSchemaResp = semanticLayer.queryByStruct(queryStructReq, user); Set dimensionValues = new HashSet<>(); 
queryResultWithSchemaResp.getResultList().removeIf(o -> { if (dimensionValues.contains(o.get(dimensionValueReq.getBizName()))) { diff --git a/chat/core/src/main/python/few_shot_example/sql_exampler.py b/chat/core/src/main/python/few_shot_example/sql_exampler.py index 9569fe97a..454144f85 100644 --- a/chat/core/src/main/python/few_shot_example/sql_exampler.py +++ b/chat/core/src/main/python/few_shot_example/sql_exampler.py @@ -1,348 +1,371 @@ -examplars= [ - { "current_date":"2020-12-01", - "table_name":"内容库产品", - "fields_list":"""["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""", - "question":"比较jackjchen和robinlee在内容库的访问次数", - "prior_schema_links":"""['jackjchen'->用户名, 'robinlee'->用户名]""", +examplars = [ + { + "current_date": "2020-12-01", + "table_name": "内容库产品", + "fields_list": """["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""", + "question": "比较jackjchen和robinlee在内容库的访问次数", + "prior_schema_links": """['jackjchen'->用户名, 'robinlee'->用户名]""", "analysis": """让我们一步一步地思考。在问题“比较jackjchen和robinlee在内容库的访问次数“中,我们被问: “比较jackjchen和robinlee”,所以我们需要column=[用户名] ”内容库的访问次数“,所以我们需要column=[访问次数] 基于table和columns,可能的cell values 是 = ['jackjchen', 'robinlee']。""", - "schema_links":"""["用户名", "访问次数", "'jackjchen'", "'robinlee'"]""", - "sql":"""select 用户名, 访问次数 from 内容库产品 where 用户名 in ('jackjchen', 'robinlee') and 数据日期 = '2020-12-01' """ - }, - { "current_date":"2022-11-06", - "table_name":"内容库产品", - "fields_list":"""["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""", - "question":"内容库近12个月访问人数 按部门", - "prior_schema_links":"""[]""", + "schema_links": """["用户名", "访问次数", "'jackjchen'", "'robinlee'"]""", + "sql": """select 用户名, 访问次数 from 内容库产品 where 用户名 in ('jackjchen', 'robinlee') and 数据日期 = '2020-12-01' """, + }, + { + "current_date": "2022-11-06", + "table_name": "内容库产品", + "fields_list": """["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""", + "question": "内容库近12个月访问人数 按部门", + "prior_schema_links": """[]""", "analysis": """让我们一步一步地思考。在问题“内容库近12个月访问人数 按部门“中,我们被问: 
”内容库近12个月“,所以我们需要column=[数据日期] “访问人数”,所以我们需要column=[访问人数] ”按部门“,所以我们需要column=[部门] 基于table和columns,可能的cell values 是 = [12]。""", - "schema_links":"""["访问人数", "部门", "数据日期", 12]""", - "sql":"""select 部门, 数据日期, 访问人数 from 内容库产品 where datediff('month', 数据日期, '2022-11-06') <= 12 """ - }, - { "current_date":"2023-04-21", - "table_name":"内容库产品", - "fields_list":"""["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""", - "question":"内容库美术部、技术研发部的访问时长", - "prior_schema_links":"""['美术部'->部门, '技术研发部'->部门]""", + "schema_links": """["访问人数", "部门", "数据日期", 12]""", + "sql": """select 部门, 数据日期, 访问人数 from 内容库产品 where datediff('month', 数据日期, '2022-11-06') <= 12 """, + }, + { + "current_date": "2023-04-21", + "table_name": "内容库产品", + "fields_list": """["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""", + "question": "内容库美术部、技术研发部的访问时长", + "prior_schema_links": """['美术部'->部门, '技术研发部'->部门]""", "analysis": """让我们一步一步地思考。在问题“内容库美术部、技术研发部的访问时长“中,我们被问: “访问时长”,所以我们需要column=[访问时长] ”内容库美术部、技术研发部“,所以我们需要column=[部门] 基于table和columns,可能的cell values 是 = ['美术部', '技术研发部']。""", - "schema_links":"""["访问时长", "部门", "'美术部'", "'技术研发部'"]""", - "sql":"""select 部门, 访问时长 from 内容库产品 where 部门 in ('美术部', '技术研发部') and 数据日期 = '2023-04-21' """ - }, - { "current_date":"2023-08-21", - "table_name":"严选", - "fields_list":"""["严选版权归属系", "付费模式", "结算播放份额", "付费用户结算播放份额", "数据日期"]""", - "question":"近3天海田飞系MPPM结算播放份额", - "prior_schema_links":"""['海田飞系'->严选版权归属系]""", + "schema_links": """["访问时长", "部门", "'美术部'", "'技术研发部'"]""", + "sql": """select 部门, 访问时长 from 内容库产品 where 部门 in ('美术部', '技术研发部') and 数据日期 = '2023-04-21' """, + }, + { + "current_date": "2023-08-21", + "table_name": "严选", + "fields_list": """["严选版权归属系", "付费模式", "结算播放份额", "付费用户结算播放份额", "数据日期"]""", + "question": "近3天海田飞系MPPM结算播放份额", + "prior_schema_links": """['海田飞系'->严选版权归属系]""", "analysis": """让我们一步一步地思考。在问题“近3天海田飞系MPPM结算播放份额“中,我们被问: “MPPM结算播放份额”,所以我们需要column=[结算播放份额] ”海田飞系“,所以我们需要column=[严选版权归属系] ”近3天“,所以我们需要column=[数据日期] 基于table和columns,可能的cell values 是 = 
['海田飞系', 3]。""", - "schema_links":"""["结算播放份额", "严选版权归属系", "数据日期", "'海田飞系'", 3]""", - "sql":"""select 严选版权归属系, 结算播放份额 from 严选 where 严选版权归属系 = '海田飞系' and datediff('day', 数据日期, '2023-08-21') <= 3 """ - }, - { "current_date":"2023-05-22", - "table_name":"歌曲库", - "fields_list":"""["是否潮流人歌曲", "C音歌曲ID", "C音歌曲MID", "歌曲名", "歌曲版本", "语种", "歌曲类型", "翻唱类型", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "结算播放量", "运营播放量", "付费用户结算播放量", "历史累计结算播放量", "运营搜播量", "结算搜播量", "运营完播量", "运营推播量", "近7日复播率", "日均搜播量", "数据日期"]""", - "question":"对比近7天翻唱版和纯音乐的歌曲播放量", - "prior_schema_links":"""['纯音乐'->语种, '翻唱版'->歌曲版本]""", + "schema_links": """["结算播放份额", "严选版权归属系", "数据日期", "'海田飞系'", 3]""", + "sql": """select 严选版权归属系, 结算播放份额 from 严选 where 严选版权归属系 = '海田飞系' and datediff('day', 数据日期, '2023-08-21') <= 3 """, + }, + { + "current_date": "2023-05-22", + "table_name": "歌曲库", + "fields_list": """["是否潮流人歌曲", "C音歌曲ID", "C音歌曲MID", "歌曲名", "歌曲版本", "语种", "歌曲类型", "翻唱类型", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "结算播放量", "运营播放量", "付费用户结算播放量", "历史累计结算播放量", "运营搜播量", "结算搜播量", "运营完播量", "运营推播量", "近7日复播率", "日均搜播量", "数据日期"]""", + "question": "对比近7天翻唱版和纯音乐的歌曲播放量", + "prior_schema_links": """['纯音乐'->语种, '翻唱版'->歌曲版本]""", "analysis": """让我们一步一步地思考。在问题“对比近3天翻唱版和纯音乐的歌曲播放量“中,我们被问: “歌曲播放量”,所以我们需要column=[结算播放量] ”翻唱版“,所以我们需要column=[歌曲版本] ”和纯音乐的歌曲“,所以我们需要column=[语种] ”近7天“,所以我们需要column=[数据日期] 基于table和columns,可能的cell values 是 = ['翻唱版', '纯音乐', 7]。""", - "schema_links":"""["结算播放量", "歌曲版本", "语种", "数据日期", "'翻唱版'", "'纯音乐'", 7]""", - "sql":"""select 歌曲版本, 语种, 结算播放量 from 歌曲库 where 歌曲版本 = '翻唱版' and 语种 = '纯音乐' and datediff('day', 数据日期, '2023-05-22') <= 7 """ - }, - { "current_date":"2023-05-31", - "table_name":"艺人库", - "fields_list":"""["上下架状态", "歌手名", "歌手等级", "歌手类型", "歌手来源", "MPPM潮流人等级", "活跃区域", "年龄", "歌手才能", "歌手风格", "粉丝数", "潮音粉丝数", "超声波粉丝数", "推博粉丝数", "超声波歌曲数", "在架歌曲数", "超声波分享数", "独占歌曲数", "超声波在架歌曲评论数", "有播放量歌曲数", "数据日期"]""", - "question":"对比一下陈拙悬、孟梅琦、赖媚韵的粉丝数", - "prior_schema_links":"""['1527896'->MPPM歌手ID, '1565463'->MPPM歌手ID, '2141459'->MPPM歌手ID]""", + 
"schema_links": """["结算播放量", "歌曲版本", "语种", "数据日期", "'翻唱版'", "'纯音乐'", 7]""", + "sql": """select 歌曲版本, 语种, 结算播放量 from 歌曲库 where 歌曲版本 = '翻唱版' and 语种 = '纯音乐' and datediff('day', 数据日期, '2023-05-22') <= 7 """, + }, + { + "current_date": "2023-05-31", + "table_name": "艺人库", + "fields_list": """["上下架状态", "歌手名", "歌手等级", "歌手类型", "歌手来源", "MPPM潮流人等级", "活跃区域", "年龄", "歌手才能", "歌手风格", "粉丝数", "潮音粉丝数", "超声波粉丝数", "推博粉丝数", "超声波歌曲数", "在架歌曲数", "超声波分享数", "独占歌曲数", "超声波在架歌曲评论数", "有播放量歌曲数", "数据日期"]""", + "question": "对比一下陈拙悬、孟梅琦、赖媚韵的粉丝数", + "prior_schema_links": """['1527896'->MPPM歌手ID, '1565463'->MPPM歌手ID, '2141459'->MPPM歌手ID]""", "analysis": """让我们一步一步地思考。在问题“对比一下陈拙悬、孟梅琦、赖媚韵的粉丝数“中,我们被问: “粉丝数”,所以我们需要column=[粉丝数] ”陈拙悬、孟梅琦、赖媚韵“,所以我们需要column=[歌手名] 基于table和columns,可能的cell values 是 = ['陈拙悬', '孟梅琦', '赖媚韵']。""", - "schema_links":"""["粉丝数", "歌手名", "'陈拙悬'", "'孟梅琦'", "'赖媚韵'"]""", - "sql":"""select 歌手名, 粉丝数 from 艺人库 where 歌手名 in ('陈拙悬', '孟梅琦', '赖媚韵') and 数据日期 = '2023-05-31' """ - }, - { "current_date":"2023-07-31", - "table_name":"歌曲库", - "fields_list":"""["歌曲名", "歌曲版本", "歌曲类型", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""", - "question":"播放量大于1万的歌曲有多少", - "prior_schema_links":"""[]""", + "schema_links": """["粉丝数", "歌手名", "'陈拙悬'", "'孟梅琦'", "'赖媚韵'"]""", + "sql": """select 歌手名, 粉丝数 from 艺人库 where 歌手名 in ('陈拙悬', '孟梅琦', '赖媚韵') and 数据日期 = '2023-05-31' """, + }, + { + "current_date": "2023-07-31", + "table_name": "歌曲库", + "fields_list": """["歌曲名", "歌曲版本", "歌曲类型", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""", + "question": "播放量大于1万的歌曲有多少", + "prior_schema_links": """[]""", "analysis": """让我们一步一步地思考。在问题“播放量大于1万的歌曲有多少“中,我们被问: “歌曲有多少”,所以我们需要column=[歌曲名] ”播放量大于1万的“,所以我们需要column=[结算播放量] 基于table和columns,可能的cell values 是 = [10000]。""", - 
"schema_links":"""["歌曲名", "结算播放量", 10000]""", - "sql":"""select 歌曲名 from 歌曲库 where 结算播放量 > 10000 and 数据日期 = '2023-07-31' """ - }, - { "current_date":"2023-07-31", - "table_name":"内容库产品", - "fields_list":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""", - "question":"内容库访问时长小于1小时,且来自美术部的用户是哪些", - "prior_schema_links":"""['美术部'->部门]""", + "schema_links": """["歌曲名", "结算播放量", 10000]""", + "sql": """select 歌曲名 from 歌曲库 where 结算播放量 > 10000 and 数据日期 = '2023-07-31' """, + }, + { + "current_date": "2023-07-31", + "table_name": "内容库产品", + "fields_list": """["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""", + "question": "内容库访问时长小于1小时,且来自美术部的用户是哪些", + "prior_schema_links": """['美术部'->部门]""", "analysis": """让我们一步一步地思考。在问题“内容库访问时长小于1小时,且来自美术部的用户是哪些“中,我们被问: “用户是哪些”,所以我们需要column=[用户名] ”美术部的“,所以我们需要column=[部门] ”访问时长小于1小时“,所以我们需要column=[访问时长] 基于table和columns,可能的cell values 是 = ['美术部', 1]。""", - "schema_links":"""["用户名", "部门", "访问时长", "'美术部'", 1]""", - "sql":"""select 用户名 from 内容库产品 where 部门 = '美术部' and 访问时长 < 1 and 数据日期 = '2023-07-31' """ - }, - { "current_date":"2023-08-31", - "table_name":"内容库产品", - "fields_list":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""", - "question":"内容库pv最高的用户有哪些", - "prior_schema_links":"""[]""", + "schema_links": """["用户名", "部门", "访问时长", "'美术部'", 1]""", + "sql": """select 用户名 from 内容库产品 where 部门 = '美术部' and 访问时长 < 1 and 数据日期 = '2023-07-31' """, + }, + { + "current_date": "2023-08-31", + "table_name": "内容库产品", + "fields_list": """["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""", + "question": "内容库pv最高的用户有哪些", + "prior_schema_links": """[]""", "analysis": """让我们一步一步地思考。在问题“内容库pv最高的用户有哪些“中,我们被问: “用户有哪些”,所以我们需要column=[用户名] ”pv最高的“,所以我们需要column=[访问次数] 基于table和columns,可能的cell values 是 = []。""", - "schema_links":"""["用户名", "访问次数"]""", - "sql":"""select 用户名 from 内容库产品 where 数据日期 = '2023-08-31' order by 访问次数 desc limit 10 """ - }, - { "current_date":"2023-08-31", - "table_name":"艺人库", - "fields_list":"""["播放量层级", "播放量单调性", "播放量方差", 
"播放量突增类型", "播放量集中度", "歌手名", "歌手等级", "歌手类型", "歌手来源", "MPPM潮流人等级", "结算播放量", "运营播放量", "历史累计结算播放量", "有播放量歌曲数", "历史累计运营播放量", "付费用户结算播放量", "结算播放量占比", "运营播放份额", "免费用户结算播放占比", "完播量", "数据日期"]""", - "question":"近90天袁亚伟播放量平均值是多少", - "prior_schema_links":"""['152789226'->MPPM歌手ID]""", + "schema_links": """["用户名", "访问次数"]""", + "sql": """select 用户名 from 内容库产品 where 数据日期 = '2023-08-31' order by 访问次数 desc limit 10 """, + }, + { + "current_date": "2023-08-31", + "table_name": "艺人库", + "fields_list": """["播放量层级", "播放量单调性", "播放量方差", "播放量突增类型", "播放量集中度", "歌手名", "歌手等级", "歌手类型", "歌手来源", "MPPM潮流人等级", "结算播放量", "运营播放量", "历史累计结算播放量", "有播放量歌曲数", "历史累计运营播放量", "付费用户结算播放量", "结算播放量占比", "运营播放份额", "免费用户结算播放占比", "完播量", "数据日期"]""", + "question": "近90天袁亚伟播放量平均值是多少", + "prior_schema_links": """['152789226'->MPPM歌手ID]""", "analysis": """让我们一步一步地思考。在问题“近90天袁亚伟播放量平均值是多少“中,我们被问: “播放量平均值是多少”,所以我们需要column=[结算播放量] ”袁亚伟“,所以我们需要column=[歌手名] ”近90天“,所以我们需要column=[数据日期] 基于table和columns,可能的cell values 是 = ['袁亚伟', 90]。""", - "schema_links":"""["结算播放量", "歌手名", "数据日期", "'袁亚伟'", 90]""", - "sql":"""select avg(结算播放量) from 艺人库 where 歌手名 = '袁亚伟' and datediff('day', 数据日期, '2023-08-31') <= 90 """ - }, - { "current_date":"2023-08-31", - "table_name":"艺人库", - "fields_list":"""["播放量层级", "播放量单调性", "播放量方差", "播放量突增类型", "播放量集中度", "歌手名", "歌手等级", "歌手类型", "歌手来源", "MPPM潮流人等级", "结算播放量", "运营播放量", "历史累计结算播放量", "有播放量歌曲数", "历史累计运营播放量", "付费用户结算播放量", "结算播放量占比", "运营播放份额", "免费用户结算播放占比", "完播量", "数据日期"]""", - "question":"周倩倩近7天结算播放量总和是多少", - "prior_schema_links":"""['199509'->MPPM歌手ID]""", + "schema_links": """["结算播放量", "歌手名", "数据日期", "'袁亚伟'", 90]""", + "sql": """select avg(结算播放量) from 艺人库 where 歌手名 = '袁亚伟' and datediff('day', 数据日期, '2023-08-31') <= 90 """, + }, + { + "current_date": "2023-08-31", + "table_name": "艺人库", + "fields_list": """["播放量层级", "播放量单调性", "播放量方差", "播放量突增类型", "播放量集中度", "歌手名", "歌手等级", "歌手类型", "歌手来源", "MPPM潮流人等级", "结算播放量", "运营播放量", "历史累计结算播放量", "有播放量歌曲数", "历史累计运营播放量", "付费用户结算播放量", "结算播放量占比", "运营播放份额", "免费用户结算播放占比", "完播量", 
"数据日期"]""", + "question": "周倩倩近7天结算播放量总和是多少", + "prior_schema_links": """['199509'->MPPM歌手ID]""", "analysis": """让我们一步一步地思考。在问题“周倩倩近7天结算播放量总和是多少“中,我们被问: “结算播放量总和是多少”,所以我们需要column=[结算播放量] ”周倩倩“,所以我们需要column=[歌手名] ”近7天“,所以我们需要column=[数据日期] 基于table和columns,可能的cell values 是 = ['周倩倩', 7]。""", - "schema_links":"""["结算播放量", "歌手名", "数据日期", "'周倩倩'", 7]""", - "sql":"""select sum(结算播放量) from 艺人库 where 歌手名 = '周倩倩' and datediff('day', 数据日期, '2023-08-31') <= 7 """ - }, - { "current_date":"2023-09-14", - "table_name":"内容库产品", - "fields_list":"""["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""", - "question":"内容库访问次数大于1k的部门是哪些", - "prior_schema_links":"""[]""", + "schema_links": """["结算播放量", "歌手名", "数据日期", "'周倩倩'", 7]""", + "sql": """select sum(结算播放量) from 艺人库 where 歌手名 = '周倩倩' and datediff('day', 数据日期, '2023-08-31') <= 7 """, + }, + { + "current_date": "2023-09-14", + "table_name": "内容库产品", + "fields_list": """["部门", "模块", "用户名", "访问次数", "访问人数", "访问时长", "数据日期"]""", + "question": "内容库访问次数大于1k的部门是哪些", + "prior_schema_links": """[]""", "analysis": """让我们一步一步地思考。在问题“内容库访问次数大于1k的部门是哪些“中,我们被问: “部门是哪些”,所以我们需要column=[部门] ”访问次数大于1k的“,所以我们需要column=[访问次数] 基于table和columns,可能的cell values 是 = [1000]。""", - "schema_links":"""["部门", "访问次数", 1000]""", - "sql":"""select 部门 from 内容库产品 where 访问次数 > 1000 and 数据日期 = '2023-09-14' """ - }, - { "current_date":"2023-09-18", - "table_name":"歌曲库", - "fields_list":"""["歌曲名", "MPPM歌手ID", "歌曲版本", "歌曲类型", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""", - "question":"陈亿训唱的所有的播放量大于20k的孤勇者有哪些", - "prior_schema_links":"""['199509'->MPPM歌手ID, '1527123'->MPPM歌曲ID]""", + "schema_links": """["部门", "访问次数", 1000]""", + "sql": """select 部门 from 内容库产品 where 访问次数 > 1000 and 数据日期 = '2023-09-14' """, + }, + { + "current_date": "2023-09-18", + "table_name": "歌曲库", + "fields_list": """["歌曲名", "MPPM歌手ID", "歌曲版本", "歌曲类型", "MPPM歌曲ID", "是否严选窄口径歌曲", 
"是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""", + "question": "陈亿训唱的所有的播放量大于20k的孤勇者有哪些", + "prior_schema_links": """['199509'->MPPM歌手ID, '1527123'->MPPM歌曲ID]""", "analysis": """让我们一步一步地思考。在问题“陈亿训唱的所有的播放量大于20k的孤勇者有哪些“中,我们被问: “孤勇者有哪些”,所以我们需要column=[歌曲名] ”播放量大于20k的“,所以我们需要column=[结算播放量] ”陈亿训唱的“,所以我们需要column=[歌手名] 基于table和columns,可能的cell values 是 = [20000, '陈亿训', '孤勇者']。""", - "schema_links":"""["歌曲名", "结算播放量", "歌手名", 20000, "'陈亿训'", "'孤勇者'"]""", - "sql":"""select 歌曲名 from 歌曲库 where 结算播放量 > 20000 and 歌手名 = '陈亿训' and 歌曲名 = '孤勇者' and 数据日期 = '2023-09-18' """ - }, - { "current_date":"2023-09-18", - "table_name":"歌曲库", - "fields_list":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""", - "question":"周洁轮去年发布的歌曲有哪些", - "prior_schema_links":"""['23109'->MPPM歌手ID]""", + "schema_links": """["歌曲名", "结算播放量", "歌手名", 20000, "'陈亿训'", "'孤勇者'"]""", + "sql": """select 歌曲名 from 歌曲库 where 结算播放量 > 20000 and 歌手名 = '陈亿训' and 歌曲名 = '孤勇者' and 数据日期 = '2023-09-18' """, + }, + { + "current_date": "2023-09-18", + "table_name": "歌曲库", + "fields_list": """["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""", + "question": "周洁轮去年发布的歌曲有哪些", + "prior_schema_links": """['23109'->MPPM歌手ID]""", "analysis": """让我们一步一步地思考。在问题“周洁轮去年发布的歌曲有哪些“中,我们被问: “歌曲有哪些”,所以我们需要column=[歌曲名] ”去年发布的“,所以我们需要column=[发布时间] ”周洁轮“,所以我们需要column=[歌手名] 基于table和columns,可能的cell values 是 = ['周洁轮', 1]。""", - "schema_links":"""["歌曲名", "发布时间", "歌手名", 1, "'周洁轮'"]""", - "sql":"""select 歌曲名 from 歌曲库 where datediff('year', 发布时间, '2023-09-18') <= 1 and 歌手名 = '周洁轮' and 数据日期 = '2023-09-18' """ - }, - { 
"current_date":"2023-09-11", - "table_name":"艺人库", - "fields_list":"""["播放量层级", "播放量单调性", "播放量方差", "播放量突增类型", "播放量集中度", "歌手名", "歌手等级", "歌手类型", "歌手来源", "签约日期", "MPPM潮流人等级", "结算播放量", "运营播放量", "历史累计结算播放量", "有播放量歌曲数", "历史累计运营播放量", "付费用户结算播放量", "结算播放量占比", "运营播放份额", "免费用户结算播放占比", "完播量", "数据日期"]""", - "question":"我想要近半年签约的播放量前十的歌手有哪些", - "prior_schema_links":"""[]""", + "schema_links": """["歌曲名", "发布时间", "歌手名", 1, "'周洁轮'"]""", + "sql": """select 歌曲名 from 歌曲库 where datediff('year', 发布时间, '2023-09-18') <= 1 and 歌手名 = '周洁轮' and 数据日期 = '2023-09-18' """, + }, + { + "current_date": "2023-09-11", + "table_name": "艺人库", + "fields_list": """["播放量层级", "播放量单调性", "播放量方差", "播放量突增类型", "播放量集中度", "歌手名", "歌手等级", "歌手类型", "歌手来源", "签约日期", "MPPM潮流人等级", "结算播放量", "运营播放量", "历史累计结算播放量", "有播放量歌曲数", "历史累计运营播放量", "付费用户结算播放量", "结算播放量占比", "运营播放份额", "免费用户结算播放占比", "完播量", "数据日期"]""", + "question": "我想要近半年签约的播放量前十的歌手有哪些", + "prior_schema_links": """[]""", "analysis": """让我们一步一步地思考。在问题“我想要近半年签约的播放量前十的歌手“中,我们被问: “歌手有哪些”,所以我们需要column=[歌手名] ”播放量前十的“,所以我们需要column=[结算播放量] ”近半年签约的“,所以我们需要column=[签约日期] 基于table和columns,可能的cell values 是 = [0.5, 10]。""", - "schema_links":"""["歌手名", "结算播放量", "签约日期", 0.5, 10]""", - "sql":"""select 歌手名 from 艺人库 where datediff('year', 签约日期, '2023-09-11') <= 0.5 and 数据日期 = '2023-09-11' order by 结算播放量 desc limit 10""" - }, - { "current_date":"2023-08-12", - "table_name":"歌曲库", + "schema_links": """["歌手名", "结算播放量", "签约日期", 0.5, 10]""", + "sql": """select 歌手名 from 艺人库 where datediff('year', 签约日期, '2023-09-11') <= 0.5 and 数据日期 = '2023-09-11' order by 结算播放量 desc limit 10""", + }, + { + "current_date": "2023-08-12", + "table_name": "歌曲库", "fields_list": """["发行日期", "歌曲语言", "歌曲来源", "歌曲流派", "歌曲名", "歌曲版本", "歌曲类型", "发行时间", "数据日期"]""", - "question":"最近一年发行的歌曲中,有哪些在近7天播放超过一千万的", - "prior_schema_links":"""[]""", + "question": "最近一年发行的歌曲中,有哪些在近7天播放超过一千万的", + "prior_schema_links": """[]""", "analysis": """让我们一步一步地思考。在问题“最近一年发行的歌曲中,有哪些在近7天播放超过一千万的“中,我们被问: “发行的歌曲中,有哪些”,所以我们需要column=[歌曲名] 
”最近一年发行的“,所以我们需要column=[发行日期] ”在近7天播放超过一千万的“,所以我们需要column=[数据日期, 结算播放量] 基于table和columns,可能的cell values 是 = [1, 10000000]""", - "schema_links":"""["歌曲名", "发行日期", "数据日期", "结算播放量", 1, 10000000]""", - "sql":"""select 歌曲名 from 歌曲库 where datediff('year', 发行日期, '2023-08-12') <= 1 and datediff('day', 数据日期, '2023-08-12') <= 7 and 结算播放量 > 10000000""" - }, - { "current_date":"2023-08-12", - "table_name":"歌曲库", + "schema_links": """["歌曲名", "发行日期", "数据日期", "结算播放量", 1, 10000000]""", + "sql": """select 歌曲名 from 歌曲库 where datediff('year', 发行日期, '2023-08-12') <= 1 and datediff('day', 数据日期, '2023-08-12') <= 7 and 结算播放量 > 10000000""", + }, + { + "current_date": "2023-08-12", + "table_name": "歌曲库", "fields_list": """["发行日期", "歌曲语言", "歌曲来源", "歌曲流派", "歌曲名", "歌曲版本", "歌曲类型", "发行时间", "数据日期"]""", - "question":"今年以来发行的歌曲中,有哪些在近7天播放超过一千万的", - "prior_schema_links":"""[]""", + "question": "今年以来发行的歌曲中,有哪些在近7天播放超过一千万的", + "prior_schema_links": """[]""", "analysis": """让我们一步一步地思考。在问题“今年以来发行的歌曲中,有哪些在近7天播放超过一千万的“中,我们被问: “发行的歌曲中,有哪些”,所以我们需要column=[歌曲名] ”今年以来发行的“,所以我们需要column=[发行日期] ”在近7天播放超过一千万的“,所以我们需要column=[数据日期, 结算播放量] 基于table和columns,可能的cell values 是 = [0, 7, 10000000]""", - "schema_links":"""["歌曲名", "发行日期", "数据日期", "结算播放量", 0, 7, 10000000]""", - "sql":"""select 歌曲名 from 歌曲库 where datediff('year', 发行日期, '2023-08-12') <= 0 and datediff('day', 数据日期, '2023-08-12') <= 7 and 结算播放量 > 10000000""" - }, - { "current_date":"2023-08-12", - "table_name":"歌曲库", + "schema_links": """["歌曲名", "发行日期", "数据日期", "结算播放量", 0, 7, 10000000]""", + "sql": """select 歌曲名 from 歌曲库 where datediff('year', 发行日期, '2023-08-12') <= 0 and datediff('day', 数据日期, '2023-08-12') <= 7 and 结算播放量 > 10000000""", + }, + { + "current_date": "2023-08-12", + "table_name": "歌曲库", "fields_list": """["发行日期", "歌曲语言", "歌曲来源", "歌曲流派", "歌曲名", "歌曲版本", "歌曲类型", "发行时间", "数据日期"]""", - "question":"2023年以来发行的歌曲中,有哪些在近7天播放超过一千万的", - "prior_schema_links":"""['514129144'->MPPM歌曲ID]""", + "question": "2023年以来发行的歌曲中,有哪些在近7天播放超过一千万的", + "prior_schema_links": 
"""['514129144'->MPPM歌曲ID]""", "analysis": """让我们一步一步地思考。在问题“2023年以来发行的歌曲中,有哪些在近7天播放超过一千万的“中,我们被问: “发行的歌曲中,有哪些”,所以我们需要column=[歌曲名] ”2023年以来发行的“,所以我们需要column=[发行日期] ”在近7天播放超过一千万的“,所以我们需要column=[数据日期, 结算播放量] 基于table和columns,可能的cell values 是 = [2023, 7, 10000000]""", - "schema_links":"""["歌曲名", "发行日期", "数据日期", "结算播放量", 2023, 7, 10000000]""", - "sql":"""select 歌曲名 from 歌曲库 where YEAR(发行日期) >= 2023 and datediff('day', 数据日期, '2023-08-12') <= 7 and 结算播放量 > 10000000""" - }, - { "current_date":"2023-08-01", - "table_name":"歌曲库", - "fields_list":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""", - "question":"周洁轮2023年6月之后发布的歌曲有哪些", - "prior_schema_links":"""['23109'->MPPM歌手ID]""", + "schema_links": """["歌曲名", "发行日期", "数据日期", "结算播放量", 2023, 7, 10000000]""", + "sql": """select 歌曲名 from 歌曲库 where YEAR(发行日期) >= 2023 and datediff('day', 数据日期, '2023-08-12') <= 7 and 结算播放量 > 10000000""", + }, + { + "current_date": "2023-08-01", + "table_name": "歌曲库", + "fields_list": """["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""", + "question": "周洁轮2023年6月之后发布的歌曲有哪些", + "prior_schema_links": """['23109'->MPPM歌手ID]""", "analysis": """让我们一步一步地思考。在问题“周洁轮2023年6月之后发布的歌曲有哪些“中,我们被问: “歌曲有哪些”,所以我们需要column=[歌曲名] ”2023年6月之后发布的“,所以我们需要column=[发布时间] ”周洁轮“,所以我们需要column=[歌手名] 基于table和columns,可能的cell values 是 = ['周洁轮', 2023, 6]。""", - "schema_links":"""["歌曲名", "发布时间", "歌手名", "周洁轮", 2023, 6]""", - "sql":"""select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 2023 and MONTH(发布时间) >= 6 and 歌手名 = '周洁轮' and 数据日期 = '2023-08-01' """ - }, - { "current_date":"2023-08-01", - "table_name":"歌曲库", - "fields_list":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", 
"超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""", - "question":"邓梓琦在2023年1月5日之后发布的歌曲中,有哪些播放量大于500W的?", - "prior_schema_links":"""['2312311'->MPPM歌手ID]""", + "schema_links": """["歌曲名", "发布时间", "歌手名", "周洁轮", 2023, 6]""", + "sql": """select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 2023 and MONTH(发布时间) >= 6 and 歌手名 = '周洁轮' and 数据日期 = '2023-08-01' """, + }, + { + "current_date": "2023-08-01", + "table_name": "歌曲库", + "fields_list": """["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""", + "question": "邓梓琦在2023年1月5日之后发布的歌曲中,有哪些播放量大于500W的?", + "prior_schema_links": """['2312311'->MPPM歌手ID]""", "analysis": """让我们一步一步地思考。在问题“邓梓琦在2023年1月5日之后发布的歌曲中,有哪些播放量大于500W的?“中,我们被问: “播放量大于500W的”,所以我们需要column=[结算播放量] ”邓梓琦在2023年1月5日之后发布的“,所以我们需要column=[发布时间] ”邓梓琦“,所以我们需要column=[歌手名] 基于table和columns,可能的cell values 是 = ['邓梓琦', 2023, 1, 5, 5000000]。""", - "schema_links":"""["结算播放量", "发布时间", "歌手名", "邓梓琦", 2023, 1, 5, 5000000]""", - "sql":"""select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 2023 and MONTH(发布时间) >= 1 and DAY(发布时间) >= 5 and 歌手名 = '邓梓琦' and 结算播放量 > 5000000 and 数据日期 = '2023-08-01'""" - }, - { "current_date":"2023-09-17", - "table_name":"歌曲库", - "fields_list":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""", - "question":"2023年6月以后,张亮英播放量大于200万的歌曲有哪些?", - "prior_schema_links":"""['45453'->MPPM歌手ID]""", + "schema_links": """["结算播放量", "发布时间", "歌手名", "邓梓琦", 2023, 1, 5, 5000000]""", + "sql": """select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 2023 and MONTH(发布时间) >= 1 and DAY(发布时间) >= 5 and 歌手名 = '邓梓琦' and 结算播放量 > 5000000 and 数据日期 = '2023-08-01'""", + }, + { + "current_date": "2023-09-17", + "table_name": 
"歌曲库", + "fields_list": """["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""", + "question": "2023年6月以后,张亮英播放量大于200万的歌曲有哪些?", + "prior_schema_links": """['45453'->MPPM歌手ID]""", "analysis": """让我们一步一步地思考。在问题“2023年6月以后,张亮英播放量大于200万的歌曲有哪些?“中,我们被问: “播放量大于200万的”,所以我们需要column=[结算播放量] ”2023年6月以后,张亮英“,所以我们需要column=[数据日期, 歌手名] ”歌曲有哪些“,所以我们需要column=[歌曲名] 基于table和columns,可能的cell values 是 = ['张亮英', 2023, 6, 2000000]。""", - "schema_links":"""["结算播放量", "数据日期", "歌手名", "张亮英", 2023, 6, 2000000]""", - "sql":"""select 歌曲名 from 歌曲库 where YEAR(数据日期) >= 2023 and MONTH(数据日期) >= 6 and 歌手名 = '张亮英' and 结算播放量 > 2000000 """ - }, - { "current_date":"2023-08-16", - "table_name":"歌曲库", - "fields_list":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""", - "question":"2021年6月以后发布的李雨纯的播放量大于20万的歌曲有哪些", - "prior_schema_links":"""['23109'->MPPM歌手ID]""", + "schema_links": """["结算播放量", "数据日期", "歌手名", "张亮英", 2023, 6, 2000000]""", + "sql": """select 歌曲名 from 歌曲库 where YEAR(数据日期) >= 2023 and MONTH(数据日期) >= 6 and 歌手名 = '张亮英' and 结算播放量 > 2000000 """, + }, + { + "current_date": "2023-08-16", + "table_name": "歌曲库", + "fields_list": """["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""", + "question": "2021年6月以后发布的李雨纯的播放量大于20万的歌曲有哪些", + "prior_schema_links": """['23109'->MPPM歌手ID]""", "analysis": """让我们一步一步地思考。在问题“2021年6月以后发布的李雨纯的播放量大于20万的歌曲有哪些“中,我们被问: “播放量大于20万的”,所以我们需要column=[结算播放量] ”2021年6月以后发布的“,所以我们需要column=[发布时间] ”李雨纯“,所以我们需要column=[歌手名] 基于table和columns,可能的cell values 是 = ['李雨纯', 2021, 6, 200000]。""", - 
"schema_links":"""["结算播放量", "发布时间", "歌手名", "李雨纯", 2021, 6, 200000]""", - "sql":"""select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 2021 and MONTH(发布时间) >= 6 and 歌手名 = '李雨纯' and 结算播放量 > 200000 and 数据日期 = '2023-08-16'""" - }, - { "current_date":"2023-08-16", - "table_name":"歌曲库", - "fields_list":"""["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""", - "question":"刘锝桦在1992年4月2日到2020年5月2日之间发布的播放量大于20万的歌曲有哪些", - "prior_schema_links":"""['4234234'->MPPM歌手ID]""", + "schema_links": """["结算播放量", "发布时间", "歌手名", "李雨纯", 2021, 6, 200000]""", + "sql": """select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 2021 and MONTH(发布时间) >= 6 and 歌手名 = '李雨纯' and 结算播放量 > 200000 and 数据日期 = '2023-08-16'""", + }, + { + "current_date": "2023-08-16", + "table_name": "歌曲库", + "fields_list": """["歌曲名", "歌曲版本", "歌手名", "歌曲类型", "发布时间", "MPPM歌曲ID", "是否严选窄口径歌曲", "是否严选宽口径歌曲", "是否潮流人歌曲", "超声波歌曲ID", "C音歌曲ID", "C音歌曲MID", "结算播放量", "运营播放量", "分享量", "收藏量", "运营搜播量", "结算搜播量", "拉新用户数", "拉活用户数", "分享率", "结算播放份额", "数据日期"]""", + "question": "刘锝桦在1992年4月2日到2020年5月2日之间发布的播放量大于20万的歌曲有哪些", + "prior_schema_links": """['4234234'->MPPM歌手ID]""", "analysis": """让我们一步一步地思考。在问题“刘锝桦在1992年4月2日到2020年5月2日之间发布的播放量大于20万的歌曲有哪些“中,我们被问: “播放量大于20万的”,所以我们需要column=[结算播放量] ”1992年4月2日到2020年5月2日之间发布的“,所以我们需要column=[发布时间] ”刘锝桦“,所以我们需要column=[歌手名] 基于table和columns,可能的cell values 是 = ['刘锝桦', 1992, 4, 2, 2020, 5, 2, 200000]。""", - "schema_links":"""["结算播放量", "发布时间", "歌手名", "刘锝桦", 1992, 4, 2, 2020, 5, 2, 200000]""", - "sql":"""select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 1992 and MONTH(发布时间) >= 4 and DAY(发布时间) >= 2 and YEAR(发布时间) <= 2020 and MONTH(发布时间) <= 5 and DAY(发布时间) <= 2 and 歌手名 = '刘锝桦' and 结算播放量 > 200000 and 数据日期 = '2023-08-16'""" - }, + "schema_links": """["结算播放量", "发布时间", "歌手名", "刘锝桦", 1992, 4, 2, 2020, 5, 2, 200000]""", + "sql": """select 歌曲名 from 歌曲库 where YEAR(发布时间) >= 1992 and MONTH(发布时间) >= 4 and 
DAY(发布时间) >= 2 and YEAR(发布时间) <= 2020 and MONTH(发布时间) <= 5 and DAY(发布时间) <= 2 and 歌手名 = '刘锝桦' and 结算播放量 > 200000 and 数据日期 = '2023-08-16'""", + }, { - "current_date":"2023-09-04", - "table_name":"内容库产品", - "fields_list":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""", - "question":"内容库近30天访问次数的平均数", - "prior_schema_links":"""[]""", + "current_date": "2023-09-04", + "table_name": "内容库产品", + "fields_list": """["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""", + "question": "内容库近30天访问次数的平均数", + "prior_schema_links": """[]""", "analysis": """让我们一步一步地思考。在问题“内容库近30天访问次数的平均数“中,我们被问: “访问次数的平均数”,所以我们需要column=[访问次数] ”内容库近30天“,所以我们需要column=[数据日期] 基于table和columns,可能的cell values 是 = [30]。""", - "schema_links":"""["访问次数", "数据日期", 30]""", - "sql":"""select avg(访问次数) from 内容库产品 where datediff('day', 数据日期, '2023-09-04') <= 30 """ - }, + "schema_links": """["访问次数", "数据日期", 30]""", + "sql": """select avg(访问次数) from 内容库产品 where datediff('day', 数据日期, '2023-09-04') <= 30 """, + }, { - "current_date":"2023-09-04", - "table_name":"内容库产品", - "fields_list":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""", - "question":"内容库近半年哪个月的访问次数汇总最高", - "prior_schema_links":"""[]""", + "current_date": "2023-09-04", + "table_name": "内容库产品", + "fields_list": """["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""", + "question": "内容库近半年哪个月的访问次数汇总最高", + "prior_schema_links": """[]""", "analysis": """让我们一步一步地思考。在问题“内容库近半年哪个月的访问次数汇总最高“中,我们被问: “访问次数汇总最高”,所以我们需要column=[访问次数] ”内容库近半年“,所以我们需要column=[数据日期] 基于table和columns,可能的cell values 是 = [0.5]。""", - "schema_links":"""["访问次数", "数据日期", 0.5]""", - "sql":"""select MONTH(数据日期), sum(访问次数) from 内容库产品 where datediff('year', 数据日期, '2023-09-04') <= 0.5 group by MONTH(数据日期) order by sum(访问次数) desc limit 1 """ - }, + "schema_links": """["访问次数", "数据日期", 0.5]""", + "sql": """select MONTH(数据日期), sum(访问次数) from 内容库产品 where datediff('year', 数据日期, '2023-09-04') <= 0.5 group by MONTH(数据日期) order by sum(访问次数) desc limit 1 """, + }, { - 
"current_date":"2023-09-04", - "table_name":"内容库产品", - "fields_list":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""", - "question":"内容库近半年每个月的平均访问次数", - "prior_schema_links":"""[]""", + "current_date": "2023-09-04", + "table_name": "内容库产品", + "fields_list": """["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""", + "question": "内容库近半年每个月的平均访问次数", + "prior_schema_links": """[]""", "analysis": """让我们一步一步地思考。在问题“内容库近半年每个月的平均访问次数“中,我们被问: “每个月的平均访问次数”,所以我们需要column=[访问次数] ”内容库近半年“,所以我们需要column=[数据日期] 基于table和columns,可能的cell values 是 = [0.5]。""", - "schema_links":"""["访问次数", "数据日期", 0.5]""", - "sql":"""select MONTH(数据日期), avg(访问次数) from 内容库产品 where datediff('year', 数据日期, '2023-09-04') <= 0.5 group by MONTH(数据日期) """ - }, + "schema_links": """["访问次数", "数据日期", 0.5]""", + "sql": """select MONTH(数据日期), avg(访问次数) from 内容库产品 where datediff('year', 数据日期, '2023-09-04') <= 0.5 group by MONTH(数据日期) """, + }, { - "current_date":"2023-09-10", - "table_name":"内容库产品", - "fields_list":"""["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""", - "question":"内容库 按部门统计访问次数 top10 的部门", - "prior_schema_links":"""[]""", + "current_date": "2023-09-10", + "table_name": "内容库产品", + "fields_list": """["用户名", "部门", "模块", "访问时长", "访问次数", "访问人数", "数据日期"]""", + "question": "内容库 按部门统计访问次数 top10 的部门", + "prior_schema_links": """[]""", "analysis": """让我们一步一步地思考。在问题“内容库 按部门统计访问次数 top10 的部门“中,我们被问: “访问次数 top10 的部门”,所以我们需要column=[访问次数] ”内容库 按部门统计“,所以我们需要column=[部门] 基于table和columns,可能的cell values 是 = [10]。""", - "schema_links":"""["访问次数", "部门", 10]""", - "sql":"""select 部门, sum(访问次数) from 内容库产品 group by 部门 order by sum(访问次数) desc limit 10 """ - } -] \ No newline at end of file + "schema_links": """["访问次数", "部门", 10]""", + "sql": """select 部门, sum(访问次数) from 内容库产品 group by 部门 order by sum(访问次数) desc limit 10 """, + }, +] diff --git a/chat/core/src/main/python/plugin_call/prompt_construct.py b/chat/core/src/main/python/plugin_call/prompt_construct.py index 61b82cfc7..87a4d6163 100644 --- 
a/chat/core/src/main/python/plugin_call/prompt_construct.py +++ b/chat/core/src/main/python/plugin_call/prompt_construct.py @@ -14,7 +14,7 @@ def construct_plugin_prompt(tool_config): tool_name = tool_config["name"] tool_description = tool_config["description"] tool_examples = tool_config["examples"] - + prompt = "【工具名称】\n" + tool_name + "\n" prompt += "【工具描述】\n" + tool_description + "\n" @@ -23,6 +23,7 @@ def construct_plugin_prompt(tool_config): prompt += example + "\n" return prompt + def construct_plugin_pool_prompt(tool_config_list): tool_explain_list = [] for tool_config in tool_config_list: @@ -35,15 +36,20 @@ def construct_plugin_pool_prompt(tool_config_list): def construct_task_prompt(query_text, tool_explain_list_str): - instruction = """问题为:{query_text}\n请根据问题和工具的描述,选择对应的工具,完成任务。请注意,只能选择1个工具。请一步一步地分析选择工具的原因(每个工具的【工具适用问题示例】是选择的重要参考依据),并给出最终选择,输出格式为json,key为’分析过程‘, ’选择工具‘""".format(query_text=query_text) + instruction = """问题为:{query_text}\n请根据问题和工具的描述,选择对应的工具,完成任务。请注意,只能选择1个工具。请一步一步地分析选择工具的原因(每个工具的【工具适用问题示例】是选择的重要参考依据),并给出最终选择,输出格式为json,key为’分析过程‘, ’选择工具‘""".format( + query_text=query_text + ) + + prompt = "工具选择如下:\n\n{tool_explain_list_str}\n\n【任务说明】\n{instruction}".format( + instruction=instruction, tool_explain_list_str=tool_explain_list_str + ) - prompt = "工具选择如下:\n\n{tool_explain_list_str}\n\n【任务说明】\n{instruction}".format(instruction=instruction, tool_explain_list_str=tool_explain_list_str) - return prompt -def plugin_selection_output_parse(llm_output: str)-> Union[Mapping[str, str], None]: + +def plugin_selection_output_parse(llm_output: str) -> Union[Mapping[str, str], None]: try: - pattern = r'\{[^{}]+\}' + pattern = r"\{[^{}]+\}" find_result = re.findall(pattern, llm_output) result = find_result[0].strip() @@ -52,20 +58,24 @@ def plugin_selection_output_parse(llm_output: str)-> Union[Mapping[str, str], No result_dict = json.loads(result) print("result_dict: ", result_dict) - key_mapping = { - "分析过程":"analysis", - "选择工具":"toolSelection" - } + 
key_mapping = {"分析过程": "analysis", "选择工具": "toolSelection"} - converted_result_dict = {key_mapping[key]: value for key, value in result_dict.items() if key in key_mapping} + converted_result_dict = { + key_mapping[key]: value + for key, value in result_dict.items() + if key in key_mapping + } except Exception as e: print(e) converted_result_dict = None return converted_result_dict - -def plugins_config_format_convert(plugin_config_list: List[Mapping[str, Any]]) -> List[Mapping[str, Any]]: + + +def plugins_config_format_convert( + plugin_config_list: List[Mapping[str, Any]] +) -> List[Mapping[str, Any]]: plugin_config_list_new = [] for plugin_config in plugin_config_list: plugin_config_new = dict() @@ -75,7 +85,9 @@ def plugins_config_format_convert(plugin_config_list: List[Mapping[str, Any]]) - parameters = plugin_config["parameters"] examples_str = "\n".join(examples) - description_new = """{plugin_desc}\n\n例如能够处理如下问题:\n{examples_str}""".format(plugin_desc=description, examples_str=examples_str) + description_new = """{plugin_desc}\n\n例如能够处理如下问题:\n{examples_str}""".format( + plugin_desc=description, examples_str=examples_str + ) plugin_config_new["name"] = name plugin_config_new["description"] = description_new @@ -84,4 +96,3 @@ def plugins_config_format_convert(plugin_config_list: List[Mapping[str, Any]]) - plugin_config_list_new.append(plugin_config_new) return plugin_config_list_new - diff --git a/chat/core/src/main/python/plugin_call/run.py b/chat/core/src/main/python/plugin_call/run.py index 88b629ba5..769abf438 100644 --- a/chat/core/src/main/python/plugin_call/run.py +++ b/chat/core/src/main/python/plugin_call/run.py @@ -10,12 +10,19 @@ import sys sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.abspath(__file__))) -from plugin_call.prompt_construct import construct_plugin_pool_prompt, construct_task_prompt, plugin_selection_output_parse, plugins_config_format_convert +from 
plugin_call.prompt_construct import ( + construct_plugin_pool_prompt, + construct_task_prompt, + plugin_selection_output_parse, + plugins_config_format_convert, +) from util.llm_instance import llm -def plugin_selection_run(query_text: str, plugin_configs: List[Mapping[str, Any]])-> Union[Mapping[str, str], None]: - +def plugin_selection_run( + query_text: str, plugin_configs: List[Mapping[str, Any]] +) -> Union[Mapping[str, str], None]: + tools_prompt = construct_plugin_pool_prompt(plugin_configs) task_prompt = construct_task_prompt(query_text, tools_prompt) @@ -23,4 +30,3 @@ def plugin_selection_run(query_text: str, plugin_configs: List[Mapping[str, Any] parsed_output = plugin_selection_output_parse(llm_output) return parsed_output - diff --git a/chat/core/src/main/python/preset_retrieval/preset_query_db.py b/chat/core/src/main/python/preset_retrieval/preset_query_db.py index 837e58774..8a7c1c554 100644 --- a/chat/core/src/main/python/preset_retrieval/preset_query_db.py +++ b/chat/core/src/main/python/preset_retrieval/preset_query_db.py @@ -11,7 +11,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.abspath(__file__))) -def get_ids(documents:List[str]) -> List[str]: +def get_ids(documents: List[str]) -> List[str]: ids = [] for doc in documents: ids.append(str(uuid.uuid5(uuid.NAMESPACE_URL, doc))) @@ -19,25 +19,23 @@ def get_ids(documents:List[str]) -> List[str]: return ids -def add2preset_query_collection(collection:Collection, - preset_queries:List[str], - preset_query_ids:List[str] - ) -> None: +def add2preset_query_collection( + collection: Collection, preset_queries: List[str], preset_query_ids: List[str] +) -> None: - collection.add(documents=preset_queries, - ids=preset_query_ids) + collection.add(documents=preset_queries, ids=preset_query_ids) -def update_preset_query_collection(collection:Collection, - preset_queries:List[str], - preset_query_ids:List[str] - ) -> None: - - 
collection.update(documents=preset_queries, - ids=preset_query_ids) - +def update_preset_query_collection( + collection: Collection, preset_queries: List[str], preset_query_ids: List[str] +) -> None: -def query2preset_query_collection(collection:Collection, query_texts:List[str], n_results:int=10): + collection.update(documents=preset_queries, ids=preset_query_ids) + + +def query2preset_query_collection( + collection: Collection, query_texts: List[str], n_results: int = 10 +): collection_cnt = collection.count() min_n_results = 10 min_n_results = min(collection_cnt, min_n_results) @@ -56,12 +54,13 @@ def query2preset_query_collection(collection:Collection, query_texts:List[str], return res -def parse_retrieval_preset_query(res:List[Mapping[str, Any]]): - parsed_res = [[] for _ in range(0, len(res['ids']))] - retrieval_ids = res['ids'] - retrieval_distances = res['distances'] - retrieval_sentences = res['documents'] +def parse_retrieval_preset_query(res: List[Mapping[str, Any]]): + parsed_res = [[] for _ in range(0, len(res["ids"]))] + + retrieval_ids = res["ids"] + retrieval_distances = res["distances"] + retrieval_sentences = res["documents"] for query_idx in range(0, len(retrieval_ids)): id_ls = retrieval_ids[query_idx] @@ -73,43 +72,41 @@ def parse_retrieval_preset_query(res:List[Mapping[str, Any]]): distance = distance_ls[idx] sentence = sentence_ls[idx] - parsed_res[query_idx].append({ - 'id': id, - 'distance': distance, - 'presetQuery': sentence - }) + parsed_res[query_idx].append( + {"id": id, "distance": distance, "presetQuery": sentence} + ) return parsed_res -def preset_query_retrieval_format(query_list:List[str], retrieval_list:List[Mapping[str, Any]]): + +def preset_query_retrieval_format( + query_list: List[str], retrieval_list: List[Mapping[str, Any]] +): res = [] for query_idx in range(0, len(query_list)): query = query_list[query_idx] retrieval = retrieval_list[query_idx] - res.append({ - 'query': query, - 'retrieval': retrieval - }) + 
res.append({"query": query, "retrieval": retrieval}) return res -def empty_preset_query_collection(collection:Collection) -> None: + +def empty_preset_query_collection(collection: Collection) -> None: collection.delete() -def delete_preset_query_by_ids(collection:Collection, preset_query_ids:List[str]) -> None: + +def delete_preset_query_by_ids( + collection: Collection, preset_query_ids: List[str] +) -> None: collection.delete(ids=preset_query_ids) -def get_preset_query_by_ids(collection:Collection, preset_query_ids:List[str]): + +def get_preset_query_by_ids(collection: Collection, preset_query_ids: List[str]): res = collection.get(ids=preset_query_ids) return res -def preset_query_collection_size(collection:Collection) -> int: + +def preset_query_collection_size(collection: Collection) -> int: return collection.count() - - - - - - diff --git a/chat/core/src/main/python/preset_retrieval/run.py b/chat/core/src/main/python/preset_retrieval/run.py index dc501b49c..e0d7ce0d0 100644 --- a/chat/core/src/main/python/preset_retrieval/run.py +++ b/chat/core/src/main/python/preset_retrieval/run.py @@ -13,34 +13,45 @@ from chromadb.api import Collection, Documents, Embeddings from langchain.llms import OpenAI -from preset_query_db import (get_ids, add2preset_query_collection, - query2preset_query_collection, parse_retrieval_preset_query, - preset_query_retrieval_format, empty_preset_query_collection, preset_query_collection_size) +from preset_query_db import ( + get_ids, + add2preset_query_collection, + query2preset_query_collection, + parse_retrieval_preset_query, + preset_query_retrieval_format, + empty_preset_query_collection, + preset_query_collection_size, +) from util.text2vec import Text2VecEmbeddingFunction from run_config import CHROMA_DB_PERSIST_PATH, PRESET_QUERY_COLLECTION_NAME -from util.chromadb_instance import client +from util.chromadb_instance import client emb_func = Text2VecEmbeddingFunction() -collection = 
client.get_or_create_collection(name=PRESET_QUERY_COLLECTION_NAME, - embedding_function=emb_func, - metadata={"hnsw:space": "cosine"} - ) # Get a collection object from an existing collection, by name. If it doesn't exist, create it. +collection = client.get_or_create_collection( + name=PRESET_QUERY_COLLECTION_NAME, + embedding_function=emb_func, + metadata={"hnsw:space": "cosine"}, +) # Get a collection object from an existing collection, by name. If it doesn't exist, create it. print("init_preset_query_collection_size: ", preset_query_collection_size(collection)) -def preset_query_retrieval_run(collection:Collection, query_texts_list:List[str], n_results:int=5): - retrieval_res = query2preset_query_collection(collection=collection, - query_texts=query_texts_list, - n_results=n_results) +def preset_query_retrieval_run( + collection: Collection, query_texts_list: List[str], n_results: int = 5 +): + retrieval_res = query2preset_query_collection( + collection=collection, query_texts=query_texts_list, n_results=n_results + ) parsed_retrieval_res = parse_retrieval_preset_query(retrieval_res) - parsed_retrieval_res_format = preset_query_retrieval_format(query_texts_list, parsed_retrieval_res) + parsed_retrieval_res_format = preset_query_retrieval_format( + query_texts_list, parsed_retrieval_res + ) - print('parsed_retrieval_res_format: ', parsed_retrieval_res_format) + print("parsed_retrieval_res_format: ", parsed_retrieval_res_format) return parsed_retrieval_res_format diff --git a/chat/core/src/main/python/run_config.py b/chat/core/src/main/python/run_config.py index 2d4cbaf53..0b909b77d 100644 --- a/chat/core/src/main/python/run_config.py +++ b/chat/core/src/main/python/run_config.py @@ -11,7 +11,7 @@ OPENAI_API_KEY = "YOUR_API_KEY" TEMPERATURE = 0.0 -CHROMA_DB_PERSIST_DIR = 'chm_db' +CHROMA_DB_PERSIST_DIR = "chm_db" PRESET_QUERY_COLLECTION_NAME = "preset_query_collection" TEXT2DSL_COLLECTION_NAME = "text2dsl_collection" TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM = 15 @@ -21,9 
+21,9 @@ CHROMA_DB_PERSIST_PATH = os.path.join(PROJECT_DIR_PATH, CHROMA_DB_PERSIST_DIR) HF_TEXT2VEC_MODEL_NAME = "GanymedeNil/text2vec-large-chinese" -if __name__ == '__main__': - print('PROJECT_DIR_PATH: ', PROJECT_DIR_PATH) - print('EMB_MODEL_PATH: ', HF_TEXT2VEC_MODEL_NAME) - print('CHROMA_DB_PERSIST_PATH: ', CHROMA_DB_PERSIST_PATH) - print('LLMPARSER_HOST: ', LLMPARSER_HOST) - print('LLMPARSER_PORT: ', LLMPARSER_PORT) \ No newline at end of file +if __name__ == "__main__": + print("PROJECT_DIR_PATH: ", PROJECT_DIR_PATH) + print("EMB_MODEL_PATH: ", HF_TEXT2VEC_MODEL_NAME) + print("CHROMA_DB_PERSIST_PATH: ", CHROMA_DB_PERSIST_PATH) + print("LLMPARSER_HOST: ", LLMPARSER_HOST) + print("LLMPARSER_PORT: ", LLMPARSER_PORT) diff --git a/chat/core/src/main/python/sql/constructor.py b/chat/core/src/main/python/sql/constructor.py index 2553e4eca..0fb84ce5c 100644 --- a/chat/core/src/main/python/sql/constructor.py +++ b/chat/core/src/main/python/sql/constructor.py @@ -22,20 +22,34 @@ from util.text2vec import Text2VecEmbeddingFunction, hg_embedding from util.chromadb_instance import client as chromadb_client, empty_chroma_collection_2 from run_config import TEXT2DSL_COLLECTION_NAME, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM -def reload_sql_example_collection(vectorstore:Chroma, - sql_examplars:List[Mapping[str, str]], - sql_example_selector:SemanticSimilarityExampleSelector, - example_nums:int - ): + +def reload_sql_example_collection( + vectorstore: Chroma, + sql_examplars: List[Mapping[str, str]], + sql_example_selector: SemanticSimilarityExampleSelector, + example_nums: int, +): print("original sql_examples_collection size:", vectorstore._collection.count()) new_collection = empty_chroma_collection_2(collection=vectorstore._collection) vectorstore._collection = new_collection print("emptied sql_examples_collection size:", vectorstore._collection.count()) - sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums, - 
input_keys=["question"], - example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links", "current_date", "sql"]) + sql_example_selector = SemanticSimilarityExampleSelector( + vectorstore=sql_examples_vectorstore, + k=example_nums, + input_keys=["question"], + example_keys=[ + "table_name", + "fields_list", + "prior_schema_links", + "question", + "analysis", + "schema_links", + "current_date", + "sql", + ], + ) for example in sql_examplars: sql_example_selector.add_example(example) @@ -45,20 +59,36 @@ def reload_sql_example_collection(vectorstore:Chroma, return vectorstore, sql_example_selector -sql_examples_vectorstore = Chroma(collection_name=TEXT2DSL_COLLECTION_NAME, - embedding_function=hg_embedding, - client=chromadb_client) +sql_examples_vectorstore = Chroma( + collection_name=TEXT2DSL_COLLECTION_NAME, + embedding_function=hg_embedding, + client=chromadb_client, +) example_nums = TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM -sql_example_selector = SemanticSimilarityExampleSelector(vectorstore=sql_examples_vectorstore, k=example_nums, - input_keys=["question"], - example_keys=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links", "current_date", "sql"]) +sql_example_selector = SemanticSimilarityExampleSelector( + vectorstore=sql_examples_vectorstore, + k=example_nums, + input_keys=["question"], + example_keys=[ + "table_name", + "fields_list", + "prior_schema_links", + "question", + "analysis", + "schema_links", + "current_date", + "sql", + ], +) if sql_examples_vectorstore._collection.count() > 0: print("examples already in sql_vectorstore") - print("init sql_vectorstore size:", sql_examples_vectorstore._collection.count()) + print("init sql_vectorstore size:", sql_examples_vectorstore._collection.count()) print("sql_examplars size:", len(sql_examplars)) -sql_examples_vectorstore, sql_example_selector = reload_sql_example_collection(sql_examples_vectorstore, sql_examplars, sql_example_selector, 
example_nums) +sql_examples_vectorstore, sql_example_selector = reload_sql_example_collection( + sql_examples_vectorstore, sql_examplars, sql_example_selector, example_nums +) print("added sql_vectorstore size:", sql_examples_vectorstore._collection.count()) diff --git a/chat/core/src/main/python/sql/examples_reload_run.py b/chat/core/src/main/python/sql/examples_reload_run.py index 65df9087d..6a0efec2b 100644 --- a/chat/core/src/main/python/sql/examples_reload_run.py +++ b/chat/core/src/main/python/sql/examples_reload_run.py @@ -13,17 +13,31 @@ from few_shot_example.sql_exampler import examplars as sql_examplars from run_config import LLMPARSER_HOST, LLMPARSER_PORT -def text2dsl_setting_update(llm_parser_host:str, llm_parser_port:str, - sql_examplars:List[Mapping[str, str]], example_nums:int, is_shortcut:bool): +def text2dsl_setting_update( + llm_parser_host: str, + llm_parser_port: str, + sql_examplars: List[Mapping[str, str]], + example_nums: int, + is_shortcut: bool, +): url = f"http://{llm_parser_host}:{llm_parser_port}/query2sql_setting_update/" print("url: ", url) - payload = {"sqlExamplars":sql_examplars, "exampleNums":example_nums, "isShortcut":is_shortcut} - headers = {'content-type': 'application/json'} + payload = { + "sqlExamplars": sql_examplars, + "exampleNums": example_nums, + "isShortcut": is_shortcut, + } + headers = {"content-type": "application/json"} response = requests.post(url, data=json.dumps(payload), headers=headers) print(response.text) if __name__ == "__main__": - text2dsl_setting_update(LLMPARSER_HOST, LLMPARSER_PORT, - sql_examplars, TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, TEXT2DSL_IS_SHORTCUT) + text2dsl_setting_update( + LLMPARSER_HOST, + LLMPARSER_PORT, + sql_examplars, + TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM, + TEXT2DSL_IS_SHORTCUT, + ) diff --git a/chat/core/src/main/python/sql/output_parser.py b/chat/core/src/main/python/sql/output_parser.py index aa0ff317f..eeeb6bce0 100644 --- a/chat/core/src/main/python/sql/output_parser.py +++ 
b/chat/core/src/main/python/sql/output_parser.py @@ -1,21 +1,25 @@ # -*- coding:utf-8 -*- import re + def schema_link_parse(schema_link_output): try: schema_link_output = schema_link_output.strip() - pattern = r'Schema_links:(.*)' - schema_link_output = re.findall(pattern, schema_link_output, re.DOTALL)[0].strip() + pattern = r"Schema_links:(.*)" + schema_link_output = re.findall(pattern, schema_link_output, re.DOTALL)[ + 0 + ].strip() except Exception as e: print(e) schema_link_output = None return schema_link_output + def combo_schema_link_parse(schema_linking_sql_combo_output: str): try: schema_linking_sql_combo_output = schema_linking_sql_combo_output.strip() - pattern = r'Schema_links:(\[.*?\])' + pattern = r"Schema_links:(\[.*?\])" schema_links_match = re.search(pattern, schema_linking_sql_combo_output) if schema_links_match: @@ -28,10 +32,11 @@ def combo_schema_link_parse(schema_linking_sql_combo_output: str): return schema_links + def combo_sql_parse(schema_linking_sql_combo_output: str): try: schema_linking_sql_combo_output = schema_linking_sql_combo_output.strip() - pattern = r'SQL:(.*)' + pattern = r"SQL:(.*)" sql_match = re.search(pattern, schema_linking_sql_combo_output) if sql_match: diff --git a/chat/core/src/main/python/sql/prompt_maker.py b/chat/core/src/main/python/sql/prompt_maker.py index 7c4f5fccc..925cd026e 100644 --- a/chat/core/src/main/python/sql/prompt_maker.py +++ b/chat/core/src/main/python/sql/prompt_maker.py @@ -11,17 +11,31 @@ from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.example_selector import SemanticSimilarityExampleSelector -def schema_linking_exampler(user_query: str, - domain_name: str, - fields_list: List[str], - prior_schema_links: Mapping[str,str], - example_selector: SemanticSimilarityExampleSelector, - ) -> str: +def schema_linking_exampler( + user_query: str, + domain_name: str, + fields_list: List[str], + prior_schema_links: Mapping[str, str], + example_selector: 
SemanticSimilarityExampleSelector, +) -> str: - prior_schema_links_str = '['+ ','.join(["""'{}'->{}""".format(k,v) for k,v in prior_schema_links.items()]) + ']' + prior_schema_links_str = ( + "[" + + ",".join(["""'{}'->{}""".format(k, v) for k, v in prior_schema_links.items()]) + + "]" + ) - example_prompt_template = PromptTemplate(input_variables=["table_name", "fields_list", "prior_schema_links", "question", "analysis", "schema_links"], - template="Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}") + example_prompt_template = PromptTemplate( + input_variables=[ + "table_name", + "fields_list", + "prior_schema_links", + "question", + "analysis", + "schema_links", + ], + template="Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}", + ) instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links" @@ -30,81 +44,121 @@ def schema_linking_exampler(user_query: str, schema_linking_example_prompt_template = FewShotPromptTemplate( example_selector=example_selector, example_prompt=example_prompt_template, - example_separator="\n\n", + example_separator="\n\n", prefix=instruction, input_variables=["table_name", "fields_list", "prior_schema_links", "question"], - suffix=schema_linking_prompt - ) + suffix=schema_linking_prompt, + ) - schema_linking_example_prompt = schema_linking_example_prompt_template.format(table_name=domain_name, - fields_list=fields_list, - prior_schema_links=prior_schema_links_str, - question=user_query) + schema_linking_example_prompt = schema_linking_example_prompt_template.format( + table_name=domain_name, + fields_list=fields_list, + prior_schema_links=prior_schema_links_str, + question=user_query, + ) return schema_linking_example_prompt -def sql_exampler(user_query: str, - domain_name: str, - schema_link_str: str, - 
data_date: str, - example_selector: SemanticSimilarityExampleSelector, - ) -> str: - +def sql_exampler( + user_query: str, + domain_name: str, + schema_link_str: str, + data_date: str, + example_selector: SemanticSimilarityExampleSelector, +) -> str: + instruction = "# 根据schema_links为每个问题生成SQL查询语句" - sql_example_prompt_template = PromptTemplate(input_variables=["question", "current_date", "table_name", "schema_links", "sql"], - template="问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:{sql}") + sql_example_prompt_template = PromptTemplate( + input_variables=[ + "question", + "current_date", + "table_name", + "schema_links", + "sql", + ], + template="问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:{sql}", + ) sql_prompt = "问题:{question}\nCurrent_date:{current_date}\nTable {table_name}\nSchema_links:{schema_links}\nSQL:" sql_example_prompt_template = FewShotPromptTemplate( example_selector=example_selector, example_prompt=sql_example_prompt_template, - example_separator="\n\n", + example_separator="\n\n", prefix=instruction, input_variables=["question", "current_date", "table_name", "schema_links"], - suffix=sql_prompt - ) + suffix=sql_prompt, + ) - sql_example_prompt = sql_example_prompt_template.format(question=user_query, - current_date=data_date, - table_name=domain_name, - schema_links=schema_link_str) + sql_example_prompt = sql_example_prompt_template.format( + question=user_query, + current_date=data_date, + table_name=domain_name, + schema_links=schema_link_str, + ) return sql_example_prompt -def schema_linking_sql_combo_examplar(user_query: str, - domain_name: str, - data_date : str, - fields_list: List[str], - prior_schema_links: Mapping[str,str], - example_selector: SemanticSimilarityExampleSelector) -> str: - - prior_schema_links_str = '['+ ','.join(["""'{}'->{}""".format(k,v) for k,v in prior_schema_links.items()]) + ']' +def 
schema_linking_sql_combo_examplar( + user_query: str, + domain_name: str, + data_date: str, + fields_list: List[str], + prior_schema_links: Mapping[str, str], + example_selector: SemanticSimilarityExampleSelector, +) -> str: - example_prompt_template = PromptTemplate(input_variables=["table_name", "fields_list", "prior_schema_links", "current_date", "question", "analysis", "schema_links", "sql"], - template="Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}\nSQL:{sql}") + prior_schema_links_str = ( + "[" + + ",".join(["""'{}'->{}""".format(k, v) for k, v in prior_schema_links.items()]) + + "]" + ) - instruction = "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links,再根据schema_links为每个问题生成SQL查询语句" + example_prompt_template = PromptTemplate( + input_variables=[ + "table_name", + "fields_list", + "prior_schema_links", + "current_date", + "question", + "analysis", + "schema_links", + "sql", + ], + template="Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析:{analysis} 所以Schema_links是:\nSchema_links:{schema_links}\nSQL:{sql}", + ) + + instruction = ( + "# 根据数据库的表结构,参考先验信息,找出为每个问题生成SQL查询语句的schema_links,再根据schema_links为每个问题生成SQL查询语句" + ) schema_linking_sql_combo_prompt = "Table {table_name}, columns = {fields_list}, prior_schema_links = {prior_schema_links}\nCurrent_date:{current_date}\n问题:{question}\n分析: 让我们一步一步地思考。" schema_linking_sql_combo_example_prompt_template = FewShotPromptTemplate( example_selector=example_selector, example_prompt=example_prompt_template, - example_separator="\n\n", + example_separator="\n\n", prefix=instruction, - input_variables=["table_name", "fields_list", "prior_schema_links", "current_date", "question"], - suffix=schema_linking_sql_combo_prompt + input_variables=[ + "table_name", + "fields_list", + "prior_schema_links", + 
"current_date", + "question", + ], + suffix=schema_linking_sql_combo_prompt, + ) + + schema_linking_sql_combo_example_prompt = ( + schema_linking_sql_combo_example_prompt_template.format( + table_name=domain_name, + fields_list=fields_list, + prior_schema_links=prior_schema_links_str, + current_date=data_date, + question=user_query, ) - - schema_linking_sql_combo_example_prompt = schema_linking_sql_combo_example_prompt_template.format(table_name=domain_name, - fields_list=fields_list, - prior_schema_links=prior_schema_links_str, - current_date=data_date, - question=user_query) + ) return schema_linking_sql_combo_example_prompt - - diff --git a/chat/core/src/main/python/sql/run.py b/chat/core/src/main/python/sql/run.py index 02931b5c8..ddd4be8f3 100644 --- a/chat/core/src/main/python/sql/run.py +++ b/chat/core/src/main/python/sql/run.py @@ -7,133 +7,182 @@ import sys sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.abspath(__file__))) -from sql.prompt_maker import schema_linking_exampler, sql_exampler, schema_linking_sql_combo_examplar -from sql.constructor import sql_examples_vectorstore, sql_example_selector, reload_sql_example_collection -from sql.output_parser import schema_link_parse, combo_schema_link_parse, combo_sql_parse +from sql.prompt_maker import ( + schema_linking_exampler, + sql_exampler, + schema_linking_sql_combo_examplar, +) +from sql.constructor import ( + sql_examples_vectorstore, + sql_example_selector, + reload_sql_example_collection, +) +from sql.output_parser import ( + schema_link_parse, + combo_schema_link_parse, + combo_sql_parse, +) from util.llm_instance import llm from run_config import TEXT2DSL_IS_SHORTCUT + class Text2DSLAgent(object): - def __init__(self): - self.schema_linking_exampler = schema_linking_exampler - self.sql_exampler = sql_exampler + def __init__(self): + self.schema_linking_exampler = schema_linking_exampler + self.sql_exampler = sql_exampler - 
self.schema_linking_sql_combo_exampler = schema_linking_sql_combo_examplar + self.schema_linking_sql_combo_exampler = schema_linking_sql_combo_examplar - self.sql_examples_vectorstore = sql_examples_vectorstore - self.sql_example_selector = sql_example_selector + self.sql_examples_vectorstore = sql_examples_vectorstore + self.sql_example_selector = sql_example_selector - self.schema_link_parse = schema_link_parse - self.combo_schema_link_parse = combo_schema_link_parse - self.combo_sql_parse = combo_sql_parse + self.schema_link_parse = schema_link_parse + self.combo_schema_link_parse = combo_schema_link_parse + self.combo_sql_parse = combo_sql_parse - self.llm = llm + self.llm = llm - self.is_shortcut = TEXT2DSL_IS_SHORTCUT + self.is_shortcut = TEXT2DSL_IS_SHORTCUT - def update_examples(self, sql_examples, example_nums, is_shortcut): - self.sql_examples_vectorstore, self.sql_example_selector = reload_sql_example_collection(self.sql_examples_vectorstore, - sql_examples, - self.sql_example_selector, - example_nums) - self.is_shortcut = is_shortcut + def update_examples(self, sql_examples, example_nums, is_shortcut): + ( + self.sql_examples_vectorstore, + self.sql_example_selector, + ) = reload_sql_example_collection( + self.sql_examples_vectorstore, + sql_examples, + self.sql_example_selector, + example_nums, + ) + self.is_shortcut = is_shortcut - def query2sql(self, query_text: str, - schema : Union[dict, None] = None, - current_date: str = None, - linking: Union[List[Mapping[str, str]], None] = None - ): + def query2sql( + self, + query_text: str, + schema: Union[dict, None] = None, + current_date: str = None, + linking: Union[List[Mapping[str, str]], None] = None, + ): - print("query_text: ", query_text) - print("schema: ", schema) - print("current_date: ", current_date) - print("prior_schema_links: ", linking) + print("query_text: ", query_text) + print("schema: ", schema) + print("current_date: ", current_date) + print("prior_schema_links: ", linking) - if 
linking is not None: - prior_schema_links = {item['fieldValue']:item['fieldName'] for item in linking} - else: - prior_schema_links = {} + if linking is not None: + prior_schema_links = { + item["fieldValue"]: item["fieldName"] for item in linking + } + else: + prior_schema_links = {} - model_name = schema['modelName'] - fields_list = schema['fieldNameList'] + model_name = schema["modelName"] + fields_list = schema["fieldNameList"] - schema_linking_prompt = self.schema_linking_exampler(query_text, model_name, fields_list, prior_schema_links, self.sql_example_selector) - print("schema_linking_prompt->", schema_linking_prompt) - schema_link_output = self.llm(schema_linking_prompt) - schema_link_str = self.schema_link_parse(schema_link_output) - - sql_prompt = self.sql_exampler(query_text, model_name, schema_link_str, current_date, self.sql_example_selector) - print("sql_prompt->", sql_prompt) - sql_output = self.llm(sql_prompt) + schema_linking_prompt = self.schema_linking_exampler( + query_text, + model_name, + fields_list, + prior_schema_links, + self.sql_example_selector, + ) + print("schema_linking_prompt->", schema_linking_prompt) + schema_link_output = self.llm(schema_linking_prompt) + schema_link_str = self.schema_link_parse(schema_link_output) - resp = dict() - resp['query'] = query_text - resp['model'] = model_name - resp['fields'] = fields_list - resp['priorSchemaLinking'] = linking - resp['dataDate'] = current_date + sql_prompt = self.sql_exampler( + query_text, + model_name, + schema_link_str, + current_date, + self.sql_example_selector, + ) + print("sql_prompt->", sql_prompt) + sql_output = self.llm(sql_prompt) - resp['analysisOutput'] = schema_link_output - resp['schemaLinkStr'] = schema_link_str - - resp['sqlOutput'] = sql_output + resp = dict() + resp["query"] = query_text + resp["model"] = model_name + resp["fields"] = fields_list + resp["priorSchemaLinking"] = linking + resp["dataDate"] = current_date - print("resp: ", resp) + resp["analysisOutput"] 
= schema_link_output + resp["schemaLinkStr"] = schema_link_str - return resp + resp["sqlOutput"] = sql_output - def query2sqlcombo(self, query_text: str, - schema : Union[dict, None] = None, - current_date: str = None, - linking: Union[List[Mapping[str, str]], None] = None - ): + print("resp: ", resp) - print("query_text: ", query_text) - print("schema: ", schema) - print("current_date: ", current_date) - print("prior_schema_links: ", linking) + return resp - if linking is not None: - prior_schema_links = {item['fieldValue']:item['fieldName'] for item in linking} - else: - prior_schema_links = {} + def query2sqlcombo( + self, + query_text: str, + schema: Union[dict, None] = None, + current_date: str = None, + linking: Union[List[Mapping[str, str]], None] = None, + ): - model_name = schema['modelName'] - fields_list = schema['fieldNameList'] + print("query_text: ", query_text) + print("schema: ", schema) + print("current_date: ", current_date) + print("prior_schema_links: ", linking) - schema_linking_sql_combo_prompt = self.schema_linking_sql_combo_exampler(query_text, model_name, current_date, fields_list, - prior_schema_links, self.sql_example_selector) - print("schema_linking_sql_combo_prompt->", schema_linking_sql_combo_prompt) - schema_linking_sql_combo_output = self.llm(schema_linking_sql_combo_prompt) + if linking is not None: + prior_schema_links = { + item["fieldValue"]: item["fieldName"] for item in linking + } + else: + prior_schema_links = {} - schema_linking_str = self.combo_schema_link_parse(schema_linking_sql_combo_output) - sql_str = self.combo_sql_parse(schema_linking_sql_combo_output) + model_name = schema["modelName"] + fields_list = schema["fieldNameList"] - resp = dict() - resp['query'] = query_text - resp['model'] = model_name - resp['fields'] = fields_list - resp['priorSchemaLinking'] = prior_schema_links - resp['dataDate'] = current_date + schema_linking_sql_combo_prompt = self.schema_linking_sql_combo_exampler( + query_text, + model_name, + 
current_date, + fields_list, + prior_schema_links, + self.sql_example_selector, + ) + print("schema_linking_sql_combo_prompt->", schema_linking_sql_combo_prompt) + schema_linking_sql_combo_output = self.llm(schema_linking_sql_combo_prompt) - resp['analysisOutput'] = schema_linking_sql_combo_output - resp['schemaLinkStr'] = schema_linking_str - resp['sqlOutput'] = sql_str + schema_linking_str = self.combo_schema_link_parse( + schema_linking_sql_combo_output + ) + sql_str = self.combo_sql_parse(schema_linking_sql_combo_output) - print("resp: ", resp) + resp = dict() + resp["query"] = query_text + resp["model"] = model_name + resp["fields"] = fields_list + resp["priorSchemaLinking"] = prior_schema_links + resp["dataDate"] = current_date - return resp + resp["analysisOutput"] = schema_linking_sql_combo_output + resp["schemaLinkStr"] = schema_linking_str + resp["sqlOutput"] = sql_str - def query2sql_run(self, query_text: str, - schema : Union[dict, None] = None, - current_date: str = None, - linking: Union[List[Mapping[str, str]], None] = None): + print("resp: ", resp) + + return resp + + def query2sql_run( + self, + query_text: str, + schema: Union[dict, None] = None, + current_date: str = None, + linking: Union[List[Mapping[str, str]], None] = None, + ): + + if self.is_shortcut: + return self.query2sqlcombo(query_text, schema, current_date, linking) + else: + return self.query2sql(query_text, schema, current_date, linking) - if self.is_shortcut: - return self.query2sqlcombo(query_text, schema, current_date, linking) - else: - return self.query2sql(query_text, schema, current_date, linking) text2sql_agent = Text2DSLAgent() - diff --git a/chat/core/src/main/python/supersonic_llmparser.py b/chat/core/src/main/python/supersonic_llmparser.py index 40ebfe613..0eb984cb4 100644 --- a/chat/core/src/main/python/supersonic_llmparser.py +++ b/chat/core/src/main/python/supersonic_llmparser.py @@ -13,11 +13,19 @@ from fastapi import FastAPI, HTTPException from sql.run import 
text2sql_agent -from preset_retrieval.run import preset_query_retrieval_run, collection as preset_query_collection -from preset_retrieval.preset_query_db import (add2preset_query_collection, update_preset_query_collection, - empty_preset_query_collection, delete_preset_query_by_ids, - update_preset_query_collection, get_preset_query_by_ids, - preset_query_collection_size) +from preset_retrieval.run import ( + preset_query_retrieval_run, + collection as preset_query_collection, +) +from preset_retrieval.preset_query_db import ( + add2preset_query_collection, + update_preset_query_collection, + empty_preset_query_collection, + delete_preset_query_by_ids, + update_preset_query_collection, + get_preset_query_by_ids, + preset_query_collection_size, +) from plugin_call.run import plugin_selection_run @@ -27,62 +35,64 @@ from run_config import LLMPARSER_PORT app = FastAPI() - @app.post("/query2sql/") async def din_query2sql(query_body: Mapping[str, Any]): - if 'queryText' not in query_body: - raise HTTPException(status_code=400, - detail="query_text is not in query_body") + if "queryText" not in query_body: + raise HTTPException(status_code=400, detail="query_text is not in query_body") else: - query_text = query_body['queryText'] + query_text = query_body["queryText"] - if 'schema' not in query_body: + if "schema" not in query_body: raise HTTPException(status_code=400, detail="schema is not in query_body") else: - schema = query_body['schema'] + schema = query_body["schema"] - if 'currentDate' not in query_body: + if "currentDate" not in query_body: raise HTTPException(status_code=400, detail="currentDate is not in query_body") else: - current_date = query_body['currentDate'] + current_date = query_body["currentDate"] - if 'linking' not in query_body: + if "linking" not in query_body: linking = None else: - linking = query_body['linking'] + linking = query_body["linking"] - resp = text2sql_agent.query2sql_run(query_text=query_text, - schema=schema, 
current_date=current_date, linking=linking) + resp = text2sql_agent.query2sql_run( + query_text=query_text, schema=schema, current_date=current_date, linking=linking + ) return resp @app.post("/query2sql_setting_update/") async def query2sql_setting_update(query_body: Mapping[str, Any]): - if 'sqlExamplars' not in query_body: - raise HTTPException(status_code=400, - detail="sqlExamplars is not in query_body") + if "sqlExamplars" not in query_body: + raise HTTPException(status_code=400, detail="sqlExamplars is not in query_body") else: - sql_examplars = query_body['sqlExamplars'] + sql_examplars = query_body["sqlExamplars"] - if 'exampleNums' not in query_body: + if "exampleNums" not in query_body: raise HTTPException(status_code=400, detail="exampleNums is not in query_body") else: - example_nums = query_body['exampleNums'] + example_nums = query_body["exampleNums"] - if 'isShortcut' not in query_body: + if "isShortcut" not in query_body: raise HTTPException(status_code=400, detail="isShortcut is not in query_body") else: - is_shortcut = query_body['isShortcut'] + is_shortcut = query_body["isShortcut"] - text2sql_agent.update_examples(sql_examples=sql_examplars, example_nums=example_nums, is_shortcut=is_shortcut) + text2sql_agent.update_examples( + sql_examples=sql_examplars, example_nums=example_nums, is_shortcut=is_shortcut + ) return "success" @app.post("/preset_query_retrival/") async def preset_query_retrival(query_text_list: List[str], n_results: int = 5): - parsed_retrieval_res_format = preset_query_retrieval_run(preset_query_collection, query_text_list, n_results) + parsed_retrieval_res_format = preset_query_retrieval_run( + preset_query_collection, query_text_list, n_results + ) return parsed_retrieval_res_format @@ -93,27 +103,32 @@ async def preset_query_add(preset_info_list: List[Mapping[str, str]]): preset_query_ids = [] for preset_info in preset_info_list: - preset_queries.append(preset_info['preset_query']) - 
preset_query_ids.append(preset_info['preset_query_id']) + preset_queries.append(preset_info["preset_query"]) + preset_query_ids.append(preset_info["preset_query_id"]) - add2preset_query_collection(collection=preset_query_collection, - preset_queries=preset_queries, - preset_query_ids=preset_query_ids) + add2preset_query_collection( + collection=preset_query_collection, + preset_queries=preset_queries, + preset_query_ids=preset_query_ids, + ) return "success" + @app.post("/preset_query_update/") async def preset_query_update(preset_info_list: List[Mapping[str, str]]): preset_queries = [] preset_query_ids = [] for preset_info in preset_info_list: - preset_queries.append(preset_info['preset_query']) - preset_query_ids.append(preset_info['preset_query_id']) + preset_queries.append(preset_info["preset_query"]) + preset_query_ids.append(preset_info["preset_query_id"]) - update_preset_query_collection(collection=preset_query_collection, - preset_queries=preset_queries, - preset_query_ids=preset_query_ids) + update_preset_query_collection( + collection=preset_query_collection, + preset_queries=preset_queries, + preset_query_ids=preset_query_ids, + ) return "success" @@ -124,39 +139,50 @@ async def preset_query_empty(): return "success" + @app.post("/preset_delete_by_ids/") async def preset_delete_by_ids(preset_query_ids: List[str]): - delete_preset_query_by_ids(collection=preset_query_collection, preset_query_ids=preset_query_ids) + delete_preset_query_by_ids( + collection=preset_query_collection, preset_query_ids=preset_query_ids + ) return "success" + @app.post("/preset_get_by_ids/") async def preset_get_by_ids(preset_query_ids: List[str]): - preset_queries = get_preset_query_by_ids(collection=preset_query_collection, preset_query_ids=preset_query_ids) + preset_queries = get_preset_query_by_ids( + collection=preset_query_collection, preset_query_ids=preset_query_ids + ) return preset_queries + @app.get("/preset_query_size/") async def preset_query_size(): size = 
preset_query_collection_size(collection=preset_query_collection) return size + @app.post("/plugin_selection/") async def tool_selection(query_body: Mapping[str, Any]): - if 'queryText' not in query_body: + if "queryText" not in query_body: raise HTTPException(status_code=400, detail="query_text is not in query_body") else: - query_text = query_body['queryText'] + query_text = query_body["queryText"] - if 'pluginConfigs' not in query_body: - raise HTTPException(status_code=400, detail="pluginConfigs is not in query_body") + if "pluginConfigs" not in query_body: + raise HTTPException( + status_code=400, detail="pluginConfigs is not in query_body" + ) else: - plugin_configs = query_body['pluginConfigs'] + plugin_configs = query_body["pluginConfigs"] resp = plugin_selection_run(query_text=query_text, plugin_configs=plugin_configs) return resp + if __name__ == "__main__": uvicorn.run(app, host=LLMPARSER_HOST, port=LLMPARSER_PORT) diff --git a/chat/core/src/main/python/util/chromadb_instance.py b/chat/core/src/main/python/util/chromadb_instance.py index ac2af57ca..0ee4920e8 100644 --- a/chat/core/src/main/python/util/chromadb_instance.py +++ b/chat/core/src/main/python/util/chromadb_instance.py @@ -7,13 +7,15 @@ from chromadb.config import Settings from run_config import CHROMA_DB_PERSIST_PATH -client = chromadb.Client(Settings( - chroma_db_impl="duckdb+parquet", - persist_directory=CHROMA_DB_PERSIST_PATH # Optional, defaults to .chromadb/ in the current directory -)) +client = chromadb.Client( + Settings( + chroma_db_impl="duckdb+parquet", + persist_directory=CHROMA_DB_PERSIST_PATH, # Optional, defaults to .chromadb/ in the current directory + ) +) -def empty_chroma_collection_2(collection:Collection): +def empty_chroma_collection_2(collection: Collection): collection_name = collection.name client = collection._client metadata = collection.metadata @@ -21,17 +23,18 @@ def empty_chroma_collection_2(collection:Collection): client.delete_collection(collection_name) - 
new_collection = client.get_or_create_collection(name=collection_name, - metadata=metadata, - embedding_function=embedding_function) + new_collection = client.get_or_create_collection( + name=collection_name, metadata=metadata, embedding_function=embedding_function + ) size_of_new_collection = new_collection.count() - print(f'Collection {collection_name} emptied. Size of new collection: {size_of_new_collection}') + print( + f"Collection {collection_name} emptied. Size of new collection: {size_of_new_collection}" + ) return new_collection -def empty_chroma_collection(collection:Collection): +def empty_chroma_collection(collection: Collection): collection.delete() - diff --git a/chat/core/src/main/python/util/llm_instance.py b/chat/core/src/main/python/util/llm_instance.py index 97b6a58d6..15a277408 100644 --- a/chat/core/src/main/python/util/llm_instance.py +++ b/chat/core/src/main/python/util/llm_instance.py @@ -4,5 +4,6 @@ from langchain.llms import OpenAI from run_config import MODEL_NAME, OPENAI_API_KEY, TEMPERATURE -llm = OpenAI(openai_api_key=OPENAI_API_KEY, model_name=MODEL_NAME, - temperature=TEMPERATURE) +llm = OpenAI( + openai_api_key=OPENAI_API_KEY, model_name=MODEL_NAME, temperature=TEMPERATURE +) diff --git a/chat/core/src/main/python/util/text2vec.py b/chat/core/src/main/python/util/text2vec.py index fef62f984..ea5f06202 100644 --- a/chat/core/src/main/python/util/text2vec.py +++ b/chat/core/src/main/python/util/text2vec.py @@ -9,6 +9,7 @@ from run_config import HF_TEXT2VEC_MODEL_NAME hg_embedding = HuggingFaceEmbeddings(model_name=HF_TEXT2VEC_MODEL_NAME) + class Text2VecEmbeddingFunction(EmbeddingFunction): def __call__(self, texts: Documents) -> Embeddings: @@ -16,13 +17,8 @@ class Text2VecEmbeddingFunction(EmbeddingFunction): return embeddings -def get_embeddings(documents:List[str]) -> List[List[float]]: + +def get_embeddings(documents: List[str]) -> List[List[float]]: embeddings = hg_embedding.embed_documents(documents) return embeddings - - - - - 
- diff --git a/chat/core/src/main/resources/mapper/ChatQueryDOMapper.xml b/chat/core/src/main/resources/mapper/ChatQueryDOMapper.xml index f965a0882..1cba2a902 100644 --- a/chat/core/src/main/resources/mapper/ChatQueryDOMapper.xml +++ b/chat/core/src/main/resources/mapper/ChatQueryDOMapper.xml @@ -3,7 +3,7 @@ - + @@ -77,7 +77,7 @@ query_state, chat_id, score, feedback, query_text, query_result ) - values (#{questionId,jdbcType=BIGINT}, #{agentId,jdbcType=BIGINT}, #{createTime,jdbcType=TIMESTAMP}, #{userName,jdbcType=VARCHAR}, + values (#{questionId,jdbcType=BIGINT}, #{agentId,jdbcType=INTEGER}, #{createTime,jdbcType=TIMESTAMP}, #{userName,jdbcType=VARCHAR}, #{queryState,jdbcType=INTEGER}, #{chatId,jdbcType=BIGINT}, #{score,jdbcType=INTEGER}, #{feedback,jdbcType=VARCHAR}, #{queryText,jdbcType=LONGVARCHAR}, #{queryResult,jdbcType=LONGVARCHAR} ) @@ -98,9 +98,6 @@ chat_id = #{chatId,jdbcType=BIGINT}, - - agent_id = #{agentId,jdbcType=INTEGER}, - score = #{score,jdbcType=INTEGER}, @@ -116,5 +113,4 @@ where question_id = #{questionId,jdbcType=BIGINT} - diff --git a/chat/core/src/main/resources/mapper/custom/ShowCaseCustomMapper.xml b/chat/core/src/main/resources/mapper/custom/ShowCaseCustomMapper.xml index adaf36822..7dcb1d213 100644 --- a/chat/core/src/main/resources/mapper/custom/ShowCaseCustomMapper.xml +++ b/chat/core/src/main/resources/mapper/custom/ShowCaseCustomMapper.xml @@ -59,7 +59,7 @@ join ( select distinct chat_id from s2_chat_query - where query_state = 0 and agent_id = ${agentId} + where query_state = 1 and agent_id = ${agentId} order by chat_id desc limit #{start}, #{limit} ) q2 diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/DateFieldCorrectorTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/DateFieldCorrectorTest.java deleted file mode 100644 index 4bc2a3919..000000000 --- a/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/DateFieldCorrectorTest.java +++ /dev/null @@ -1,45 +0,0 @@ -package 
com.tencent.supersonic.chat.corrector; - -import static org.mockito.ArgumentMatchers.any; - -import com.tencent.supersonic.chat.api.pojo.SchemaElement; -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.chat.parser.llm.dsl.DSLDateHelper; -import org.junit.Assert; -import org.junit.jupiter.api.Test; -import org.mockito.MockedStatic; -import org.mockito.Mockito; - -class DateFieldCorrectorTest { - - @Test - void corrector() { - MockedStatic dslDateHelper = Mockito.mockStatic(DSLDateHelper.class); - - dslDateHelper.when(() -> DSLDateHelper.getReferenceDate(any())).thenReturn("2023-08-14"); - DateFieldCorrector dateFieldCorrector = new DateFieldCorrector(); - SemanticParseInfo parseInfo = new SemanticParseInfo(); - SchemaElement model = new SchemaElement(); - model.setId(2L); - parseInfo.setModel(model); - SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder() - .sql("select count(歌曲名) from 歌曲库 ") - .parseInfo(parseInfo) - .build(); - - dateFieldCorrector.correct(semanticCorrectInfo); - - Assert.assertEquals("SELECT count(歌曲名) FROM 歌曲库 WHERE 数据日期 = '2023-08-14'", semanticCorrectInfo.getSql()); - - semanticCorrectInfo = SemanticCorrectInfo.builder() - .sql("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'") - .parseInfo(parseInfo) - .build(); - - dateFieldCorrector.correct(semanticCorrectInfo); - - Assert.assertEquals("select count(歌曲名) from 歌曲库 where 数据日期 = '2023-08-14'", semanticCorrectInfo.getSql()); - - } -} diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/FieldNameCorrectorTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/FieldNameCorrectorTest.java deleted file mode 100644 index 7caae3c06..000000000 --- a/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/FieldNameCorrectorTest.java +++ /dev/null @@ -1,65 +0,0 @@ -package com.tencent.supersonic.chat.corrector; - -import 
com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.chat.parser.llm.dsl.DSLParseResult; -import com.tencent.supersonic.chat.query.llm.dsl.LLMReq; -import com.tencent.supersonic.chat.query.llm.dsl.LLMReq.ElementValue; -import com.tencent.supersonic.common.pojo.Constants; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import org.junit.Assert; -import org.junit.jupiter.api.Test; - -class FieldNameCorrectorTest { - - @Test - void corrector() { - - FieldNameCorrector corrector = new FieldNameCorrector(); - SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder() - .sql("select 歌曲名 from 歌曲库 where 专辑照片 = '七里香' and 专辑名 = '流行' and 数据日期 = '2023-08-19'") - .build(); - - SemanticParseInfo parseInfo = new SemanticParseInfo(); - - DSLParseResult dslParseResult = new DSLParseResult(); - LLMReq llmReq = new LLMReq(); - List linking = new ArrayList<>(); - ElementValue elementValue = new ElementValue(); - elementValue.setFieldValue("流行"); - elementValue.setFieldName("歌曲风格"); - linking.add(elementValue); - - ElementValue elementValue2 = new ElementValue(); - elementValue2.setFieldValue("七里香"); - elementValue2.setFieldName("歌曲名"); - linking.add(elementValue2); - - ElementValue elementValue3 = new ElementValue(); - elementValue3.setFieldValue("周杰伦"); - elementValue3.setFieldName("歌手名"); - linking.add(elementValue3); - - ElementValue elementValue4 = new ElementValue(); - elementValue4.setFieldValue("流行"); - elementValue4.setFieldName("歌曲流派"); - linking.add(elementValue4); - - llmReq.setLinking(linking); - dslParseResult.setLlmReq(llmReq); - - Map properties = new HashMap<>(); - properties.put(Constants.CONTEXT, dslParseResult); - - parseInfo.setProperties(properties); - semanticCorrectInfo.setParseInfo(parseInfo); - - corrector.correct(semanticCorrectInfo); - - Assert.assertEquals("SELECT 歌曲名 FROM 歌曲库 WHERE 歌曲名 = 
'七里香' AND 歌曲流派 = '流行' AND 数据日期 = '2023-08-19'", - semanticCorrectInfo.getSql()); - } -} \ No newline at end of file diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/FieldValueCorrectorTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/FieldValueCorrectorTest.java deleted file mode 100644 index d9afccf23..000000000 --- a/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/FieldValueCorrectorTest.java +++ /dev/null @@ -1,71 +0,0 @@ -package com.tencent.supersonic.chat.corrector; - -import static org.mockito.Mockito.when; - -import com.tencent.supersonic.chat.api.pojo.SchemaElement; -import com.tencent.supersonic.chat.api.pojo.SchemaValueMap; -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import com.tencent.supersonic.chat.api.pojo.SemanticParseInfo; -import com.tencent.supersonic.chat.api.pojo.SemanticSchema; -import com.tencent.supersonic.common.util.ContextUtils; -import com.tencent.supersonic.knowledge.service.SchemaService; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import org.junit.Assert; -import org.junit.jupiter.api.Test; -import org.mockito.MockedStatic; -import org.mockito.Mockito; - -class FieldValueCorrectorTest { - - - @Test - void corrector() { - - MockedStatic mockContextUtils = Mockito.mockStatic(ContextUtils.class); - - SchemaService mockSchemaService = Mockito.mock(SchemaService.class); - - SemanticSchema mockSemanticSchema = Mockito.mock(SemanticSchema.class); - - List dimensions = new ArrayList<>(); - List schemaValueMaps = new ArrayList<>(); - SchemaValueMap value1 = new SchemaValueMap(); - value1.setBizName("杰伦"); - value1.setTechName("周杰伦"); - value1.setAlias(Arrays.asList("周杰倫", "Jay Chou", "周董", "周先生")); - schemaValueMaps.add(value1); - - SchemaElement schemaElement = SchemaElement.builder() - .bizName("singer_name") - .name("歌手名") - .model(2L) - .schemaValueMaps(schemaValueMaps) - .build(); - dimensions.add(schemaElement); - - 
when(mockSemanticSchema.getDimensions()).thenReturn(dimensions); - when(mockSchemaService.getSemanticSchema()).thenReturn(mockSemanticSchema); - mockContextUtils.when(() -> ContextUtils.getBean(SchemaService.class)).thenReturn(mockSchemaService); - - SemanticParseInfo parseInfo = new SemanticParseInfo(); - SchemaElement model = new SchemaElement(); - model.setId(2L); - parseInfo.setModel(model); - SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder() - .sql("select count(song_name) from 歌曲库 where singer_name = '周先生'") - .parseInfo(parseInfo) - .build(); - - FieldValueCorrector corrector = new FieldValueCorrector(); - corrector.correct(semanticCorrectInfo); - - Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", semanticCorrectInfo.getSql()); - - semanticCorrectInfo.setSql("select count(song_name) from 歌曲库 where singer_name = '杰伦'"); - corrector.correct(semanticCorrectInfo); - - Assert.assertEquals("SELECT count(song_name) FROM 歌曲库 WHERE singer_name = '周杰伦'", semanticCorrectInfo.getSql()); - } -} \ No newline at end of file diff --git a/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/SelectFieldAppendCorrectorTest.java b/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/SelectFieldAppendCorrectorTest.java deleted file mode 100644 index 39db3935d..000000000 --- a/chat/core/src/test/java/com/tencent/supersonic/chat/corrector/SelectFieldAppendCorrectorTest.java +++ /dev/null @@ -1,46 +0,0 @@ -package com.tencent.supersonic.chat.corrector; - -import com.tencent.supersonic.chat.api.pojo.SemanticCorrectInfo; -import org.junit.Assert; -import org.junit.jupiter.api.Test; - -class SelectFieldAppendCorrectorTest { - - @Test - void corrector() { - SelectFieldAppendCorrector corrector = new SelectFieldAppendCorrector(); - SemanticCorrectInfo semanticCorrectInfo = SemanticCorrectInfo.builder() - .sql("select 歌曲名 from 歌曲库 where datediff('day', 发布日期, '2023-08-09') <= 1 and 歌手名 = '邓紫棋' " - + "and sys_imp_date 
= '2023-08-09' and 歌曲发布时 = '2023-08-01' order by 播放量 desc limit 11") - .build(); - - corrector.correct(semanticCorrectInfo); - - Assert.assertEquals( - "SELECT 歌曲名, 歌手名, 播放量, 歌曲发布时, 发布日期 FROM 歌曲库 WHERE " - + "datediff('day', 发布日期, '2023-08-09') <= 1 AND 歌手名 = '邓紫棋' " - + "AND sys_imp_date = '2023-08-09' AND 歌曲发布时 = '2023-08-01'" - + " ORDER BY 播放量 DESC LIMIT 11", semanticCorrectInfo.getSql()); - - semanticCorrectInfo.setSql("select 用户名 from 内容库产品 where datediff('day', 数据日期, '2023-09-14') <= 30" - + " group by 用户名 having sum(访问次数) > 2000"); - - corrector.correct(semanticCorrectInfo); - - Assert.assertEquals( - "SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE " - + "datediff('day', 数据日期, '2023-09-14') <= 30 " - + "GROUP BY 用户名 HAVING sum(访问次数) > 2000", semanticCorrectInfo.getSql()); - - semanticCorrectInfo.setSql("SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE " - + "datediff('day', 数据日期, '2023-09-14') <= 30 " - + "GROUP BY 用户名 HAVING sum(访问次数) > 2000"); - - corrector.correct(semanticCorrectInfo); - - Assert.assertEquals( - "SELECT 用户名, sum(访问次数) FROM 内容库产品 WHERE " - + "datediff('day', 数据日期, '2023-09-14') <= 30 " - + "GROUP BY 用户名 HAVING sum(访问次数) > 2000", semanticCorrectInfo.getSql()); - } -} diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/listener/ApplicationStartedListener.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/listener/ApplicationStartedListener.java index 44973ac28..24bfefba8 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/listener/ApplicationStartedListener.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/listener/ApplicationStartedListener.java @@ -7,8 +7,8 @@ import com.tencent.supersonic.knowledge.service.WordService; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections.CollectionUtils; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.boot.context.event.ApplicationStartedEvent; -import 
org.springframework.context.ApplicationListener; +import org.springframework.boot.CommandLineRunner; +import org.springframework.core.annotation.Order; import org.springframework.scheduling.annotation.Scheduled; import org.springframework.stereotype.Component; @@ -17,7 +17,8 @@ import java.util.concurrent.CompletableFuture; @Slf4j @Component -public class ApplicationStartedListener implements ApplicationListener { +@Order(5) +public class ApplicationStartedListener implements CommandLineRunner { @Autowired private KnowledgeService knowledgeService; @@ -27,7 +28,7 @@ public class ApplicationStartedListener implements ApplicationListener> modelSchemaCache = CacheBuilder.newBuilder().expireAfterWrite(10, TimeUnit.SECONDS).build(); - protected ParameterizedTypeReference> structTypeRef = - new ParameterizedTypeReference>() { - }; - @SneakyThrows public List fetchModelSchema(List ids, Boolean cacheEnable) { if (cacheEnable) { diff --git a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticLayer.java b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticLayer.java index 81a445afc..902c3c83a 100644 --- a/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticLayer.java +++ b/chat/knowledge/src/main/java/com/tencent/supersonic/knowledge/semantic/LocalSemanticLayer.java @@ -57,17 +57,13 @@ public class LocalSemanticLayer extends BaseSemanticLayer { } @Override + @SneakyThrows public QueryResultWithSchemaResp queryByDsl(QueryDslReq queryDslReq, User user) { - try { - queryService = ContextUtils.getBean(QueryService.class); - Object object = queryService.queryBySql(queryDslReq, user); - QueryResultWithSchemaResp queryResultWithSchemaResp = JsonUtil.toObject(JsonUtil.toString(object), + queryService = ContextUtils.getBean(QueryService.class); + Object object = queryService.queryBySql(queryDslReq, user); + QueryResultWithSchemaResp queryResultWithSchemaResp = 
JsonUtil.toObject(JsonUtil.toString(object), QueryResultWithSchemaResp.class); - return queryResultWithSchemaResp; - } catch (Exception e) { - log.info("queryByDsl has an exception:{}", e); - } - return null; + return queryResultWithSchemaResp; } @Override diff --git a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AggOperatorEnum.java b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AggOperatorEnum.java index 2852be869..b4ce1357f 100644 --- a/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AggOperatorEnum.java +++ b/common/src/main/java/com/tencent/supersonic/common/pojo/enums/AggOperatorEnum.java @@ -10,7 +10,7 @@ public enum AggOperatorEnum { SUM("SUM"), - DISTINCT("DISTINCT"), + COUNT_DISTINCT("COUNT_DISTINCT"), TOPN("TOPN"), diff --git a/dev/reformat b/dev/reformat new file mode 100755 index 000000000..fe6afacd5 --- /dev/null +++ b/dev/reformat @@ -0,0 +1,32 @@ +#!/bin/bash +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +set -x + + +PROFILES="-P " + +# python style checks rely on `black` in path +if ! command -v black &> /dev/null +then + echo "Skip Python lint since 'black' is not available. 
Please install 'black' by running 'pip install black==22.3.0'" +else + PROFILES="${PROFILES} spotless-python" +fi + +mvn spotless:apply $PROFILES diff --git a/docs/images/agent.png b/docs/images/agent.png deleted file mode 100644 index ffd33379a..000000000 Binary files a/docs/images/agent.png and /dev/null differ diff --git a/docs/images/chat_config.png b/docs/images/chat_config.png deleted file mode 100644 index 760cf0168..000000000 Binary files a/docs/images/chat_config.png and /dev/null differ diff --git a/docs/images/database.png b/docs/images/database.png deleted file mode 100644 index 49313675b..000000000 Binary files a/docs/images/database.png and /dev/null differ diff --git a/docs/images/datasource_base_info.png b/docs/images/datasource_base_info.png deleted file mode 100644 index 7b8797e41..000000000 Binary files a/docs/images/datasource_base_info.png and /dev/null differ diff --git a/docs/images/datasource_create.png b/docs/images/datasource_create.png deleted file mode 100644 index c99bba59d..000000000 Binary files a/docs/images/datasource_create.png and /dev/null differ diff --git a/docs/images/datasource_extend_info.png b/docs/images/datasource_extend_info.png deleted file mode 100644 index dfe83b293..000000000 Binary files a/docs/images/datasource_extend_info.png and /dev/null differ diff --git a/docs/images/datasource_list.png b/docs/images/datasource_list.png deleted file mode 100644 index 10df4f394..000000000 Binary files a/docs/images/datasource_list.png and /dev/null differ diff --git a/docs/images/datasource_sql.png b/docs/images/datasource_sql.png deleted file mode 100644 index aea8be911..000000000 Binary files a/docs/images/datasource_sql.png and /dev/null differ diff --git a/docs/images/detail_default.png b/docs/images/detail_default.png deleted file mode 100644 index 30e337194..000000000 Binary files a/docs/images/detail_default.png and /dev/null differ diff --git a/docs/images/detail_entity.png b/docs/images/detail_entity.png deleted file 
mode 100644 index 915914526..000000000 Binary files a/docs/images/detail_entity.png and /dev/null differ diff --git a/docs/images/dimension_create.png b/docs/images/dimension_create.png deleted file mode 100644 index 198273365..000000000 Binary files a/docs/images/dimension_create.png and /dev/null differ diff --git a/docs/images/dimension_list.png b/docs/images/dimension_list.png deleted file mode 100644 index 69b42e4d7..000000000 Binary files a/docs/images/dimension_list.png and /dev/null differ diff --git a/docs/images/domain.png b/docs/images/domain.png deleted file mode 100644 index cfc4cc0cd..000000000 Binary files a/docs/images/domain.png and /dev/null differ diff --git a/docs/images/metric_base_info.png b/docs/images/metric_base_info.png deleted file mode 100644 index 33968b331..000000000 Binary files a/docs/images/metric_base_info.png and /dev/null differ diff --git a/docs/images/metric_default.png b/docs/images/metric_default.png deleted file mode 100644 index f43959026..000000000 Binary files a/docs/images/metric_default.png and /dev/null differ diff --git a/docs/images/metric_list.png b/docs/images/metric_list.png deleted file mode 100644 index b87a5de8b..000000000 Binary files a/docs/images/metric_list.png and /dev/null differ diff --git a/docs/images/metric_sql_info.png b/docs/images/metric_sql_info.png deleted file mode 100644 index 350ff8df0..000000000 Binary files a/docs/images/metric_sql_info.png and /dev/null differ diff --git a/docs/images/model.png b/docs/images/model.png deleted file mode 100644 index cf08a4aa5..000000000 Binary files a/docs/images/model.png and /dev/null differ diff --git a/docs/images/nlp_config.png b/docs/images/nlp_config.png deleted file mode 100644 index e8e34d284..000000000 Binary files a/docs/images/nlp_config.png and /dev/null differ diff --git a/docs/images/plugin.png b/docs/images/plugin.png deleted file mode 100644 index 3fd0e0b84..000000000 Binary files a/docs/images/plugin.png and /dev/null differ diff --git 
a/docs/images/text2sql_config.png b/docs/images/text2sql_config.png deleted file mode 100644 index d9f641438..000000000 Binary files a/docs/images/text2sql_config.png and /dev/null differ diff --git a/docs/images/visibility_dim_value.png b/docs/images/visibility_dim_value.png deleted file mode 100644 index 34a18792b..000000000 Binary files a/docs/images/visibility_dim_value.png and /dev/null differ diff --git a/docs/images/visibility_dim_value_show.png b/docs/images/visibility_dim_value_show.png deleted file mode 100644 index 84d69a56a..000000000 Binary files a/docs/images/visibility_dim_value_show.png and /dev/null differ diff --git a/docs/images/visibility_item.png b/docs/images/visibility_item.png deleted file mode 100644 index 7e9d8acf5..000000000 Binary files a/docs/images/visibility_item.png and /dev/null differ diff --git a/docs/images/wechat_contact.jpeg b/docs/images/wechat_contact.jpeg index 3879aba3d..b7a0be7b7 100644 Binary files a/docs/images/wechat_contact.jpeg and b/docs/images/wechat_contact.jpeg differ diff --git a/docs/userguides/llm_config_cn.md b/docs/userguides/llm_config_cn.md deleted file mode 100644 index 34ef26747..000000000 --- a/docs/userguides/llm_config_cn.md +++ /dev/null @@ -1,26 +0,0 @@ -# LLM模型配置 - -### **简介** - -语言模型的使用是超音数的重要一环。能显著增强对用户的问题的理解能力,是通过对话形式与用户交互的基石之一。在本项目中对语言模型能力的应用主要在 LLM 和 Embedding 两方面;默认使用的模型中,LLM选用闭源模型 gpt-3.5-turbo-16k,Embedding模型选用开源模型 GanymedeNil/text2vec-large-chinese。用户可以根据自己实际需求进行配置更改。 - - -### **配置方式** -
- -

图1-1 LLM配置文件

-
- -1. LLM模型相关的配置,在 supersonic/chat/core/src/main/python/llm/run_config.py 进行配置。 -2. LLM采用OpenAI的闭源模型 gpt-3.5-turbo-16k,在使用时需要提供OpenAI的API-Key才能调用LLM模型,通过 OPENAI_API_KEY 变量进行配置。 -3. Embedding模型采用开源模型 GanymedeNil/text2vec-large-chinese,通过 HF_TEXT2VEC_MODEL_NAME 变量进行位置,为了使用方便采用托管在HuggingFace的源,初次启动时自动下载模型文件。 - -### **FAQ** -1. 可以用开源的LLM模型替代OpenAI的GPT模型吗? - - 暂时不能。我们测试过大部分主流的开源LLM,在实际使用中,在本项目需要LLM提供的逻辑推理和代码生成场景上,开源模型还不能满足需求。 - - 我们会持续跟进开源LLM的最新进展,在有满足要求的开源LLM后,在项目中集成私有化部署开源LLM的能力。 -2. GPT4、GPT3.5、GPT3.5-16k 这几个模型用哪个比较好? - - GPT3.5、GPT3.5-16k 均能基本满足要求,但会有输出结果不稳定的情况;GPT3.5的token长度限制为4k,在现有CoT策略下,容易出现超过长度限制的情况。 - - GPT4的输出更稳定,但费用成本远超GPT3.5,可以根据实际使用场景进行选择。 -3. Embedding模型用其他的可以吗? - - 可以。可以以该项目[text2vec]([URL](https://github.com/shibing624/text2vec))的榜单作为参考,然后在HuggingFace找到对应模型的model card,修改HF_TEXT2VEC_MODEL_NAME变量的取值。 diff --git a/docs/userguides/text2sql_cn.md b/docs/userguides/text2sql_cn.md deleted file mode 100644 index eb207271d..000000000 --- a/docs/userguides/text2sql_cn.md +++ /dev/null @@ -1,29 +0,0 @@ -# text2sql功能相关配置 - -### **简介** -text2sql的功能实现,高度依赖对LLM的应用。通过LLM生成SQL的过程中,利用小样本(few-shots-examples)通过思维链(chain-of-thoughts)的方式对LLM in-context-learning的能力进行引导,对于生成较为稳定且符合下游语法解析规则的SQL非常重要。用户可以根据自身需要,对样本池及样本的数量进行配置,使其更加符合自身业务特点。 - -### **配置方式** -1. 样本池的配置。 - - supersonic/chat/core/src/main/python/few_shot_example/sql_exampler.py 为样本池配置文件。用户可以以已有的样本作为参考,配置更贴近自身业务需求的样本,用于更好的引导LLM生成SQL。 -2. 样本数量的配置。 - - 在 supersonic/chat/core/src/main/python/run_config.py 中通过 TEXT2DSL_FEW_SHOTS_EXAMPLE_NUM 变量进行配置。 - - 默认值为15,为项目在内部实践后较优的经验值。样本少太少,对导致LLM在生成SQL的过程中缺少引导和示范,生成的SQL会更不稳定;样本太多,会增加生成SQL需要的时间和LLM的token消耗(或超过LLM的token上限)。 -3. SQL生成方式的配置 - - 在 supersonic/chat/core/src/main/python/run_config.py 中通过 TEXT2DSL_IS_SHORTCUT 变量进行配置。 - - 默认值为False;当为False时,会调用2次LLM生成SQL;当为True时,会只调用1次LLM生成SQL。相较于2次LLM调用生成的SQL,耗时会减少30-40%,token的消耗量会减少30%左右,但生成的SQL正确率会有所下降。 -
- -

图1-1 配置文件

-
- -### **运行中更新配置的脚本** -1. 如果在启动项目后,用户需要对text2sql功能的相关配置进行调试,可以在修改相关配置文件后,通过以下2种方式让配置在项目运行中让配置生效。 - - 执行 supersonic-daemon.sh reload llmparser - - 执行 python examples_reload_run.py -### **FAQ** -1. 生成一个SQL需要消耗的的LLM token数量太多了,按照openAI对token的收费标准,生成一个SQL太贵了,可以少用一些token吗? - - 可以。 用户可以根据自身需求,如配置方式1.中所示,修改样本池中的样本,选用一些更加简短的样本。如配置方式2.中所示,减少使用的样本数量。配置方式3.中所示,只调用1次LLM生成SQL。 - - 需要注意,样本和样本数量的选择对生成SQL的质量有很大的影响。过于激进的降低输入的token数量可能会降低生成SQL的质量。需要用户根据自身业务特点实测后进行平衡。 - - diff --git a/launchers/chat/src/main/resources/META-INF/spring.factories b/launchers/chat/src/main/resources/META-INF/spring.factories index c760ce104..aa0a5ff20 100644 --- a/launchers/chat/src/main/resources/META-INF/spring.factories +++ b/launchers/chat/src/main/resources/META-INF/spring.factories @@ -31,12 +31,9 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\ com.tencent.supersonic.chat.api.component.SemanticCorrector=\ - com.tencent.supersonic.chat.corrector.DateFieldCorrector, \ - com.tencent.supersonic.chat.corrector.FunctionAliasCorrector, \ - com.tencent.supersonic.chat.corrector.FieldNameCorrector, \ - com.tencent.supersonic.chat.corrector.FieldCorrector, \ - com.tencent.supersonic.chat.corrector.FunctionCorrector, \ - com.tencent.supersonic.chat.corrector.TableNameCorrector, \ - com.tencent.supersonic.chat.corrector.QueryFilterAppend, \ - com.tencent.supersonic.chat.corrector.SelectFieldAppendCorrector, \ - com.tencent.supersonic.chat.corrector.FieldValueCorrector \ No newline at end of file + com.tencent.supersonic.chat.corrector.GlobalCorrector, \ + com.tencent.supersonic.chat.corrector.TableCorrector, \ + com.tencent.supersonic.chat.corrector.GroupByCorrector, \ + com.tencent.supersonic.chat.corrector.SelectCorrector, \ + com.tencent.supersonic.chat.corrector.WhereCorrector, \ + com.tencent.supersonic.chat.corrector.HavingCorrector \ No newline at end of file diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/LoadBenchMarkDemo.java 
b/launchers/standalone/src/main/java/com/tencent/supersonic/LoadBenchMarkDemo.java new file mode 100644 index 000000000..7e3de97f0 --- /dev/null +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/LoadBenchMarkDemo.java @@ -0,0 +1,199 @@ +package com.tencent.supersonic; + +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; +import com.tencent.supersonic.semantic.api.model.enums.DimensionTypeEnum; +import com.tencent.supersonic.semantic.api.model.enums.IdentifyTypeEnum; +import com.tencent.supersonic.semantic.api.model.pojo.Dim; +import com.tencent.supersonic.semantic.api.model.pojo.DimensionTimeTypeParams; +import com.tencent.supersonic.semantic.api.model.pojo.Identify; +import com.tencent.supersonic.semantic.api.model.pojo.Measure; +import com.tencent.supersonic.semantic.api.model.request.DatasourceReq; +import com.tencent.supersonic.semantic.api.model.request.DomainReq; +import com.tencent.supersonic.semantic.api.model.request.ModelReq; +import com.tencent.supersonic.semantic.model.domain.DatasourceService; +import com.tencent.supersonic.semantic.model.domain.DomainService; +import com.tencent.supersonic.semantic.model.domain.ModelService; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.CommandLineRunner; +import org.springframework.core.annotation.Order; +import org.springframework.stereotype.Component; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +@Component +@Slf4j +@Order(2) +public class LoadBenchMarkDemo implements CommandLineRunner { + + private User user = User.getFakeUser(); + + @Value("${spring.h2.demo.enabled:false}") + private boolean demoEnable; + + @Autowired + private DomainService domainService; + @Autowired + private ModelService modelService; + 
@Autowired + private DatasourceService datasourceService; + + @Override + public void run(String... args) { + if (!demoEnable) { + return; + } + try { + addDomain(); + addModel_1(); + addDatasource_1(); + addDatasource_2(); + addDatasource_3(); + addDatasource_4(); + } catch (Exception e) { + log.error("Failed to add bench mark demo data", e); + } + + } + + public void addDomain() { + DomainReq domainReq = new DomainReq(); + domainReq.setName("测评数据-音乐"); + domainReq.setBizName("music"); + domainReq.setParentId(0L); + domainReq.setViewers(Arrays.asList("admin", "tom", "jack")); + domainReq.setViewOrgs(Collections.singletonList("admin")); + domainReq.setAdmins(Collections.singletonList("admin")); + domainReq.setAdminOrgs(Collections.emptyList()); + domainService.createDomain(domainReq, user); + } + + public void addModel_1() { + ModelReq modelReq = new ModelReq(); + modelReq.setName("测评数据-音乐"); + modelReq.setBizName("music"); + modelReq.setDomainId(2L); + modelReq.setViewers(Arrays.asList("admin", "tom", "jack")); + modelReq.setViewOrgs(Collections.singletonList("admin")); + modelReq.setAdmins(Collections.singletonList("admin")); + modelReq.setAdminOrgs(Collections.emptyList()); + modelService.createModel(modelReq, user); + } + + public void addDatasource_1() throws Exception { + DatasourceReq datasourceReq = new DatasourceReq(); + datasourceReq.setModelId(3L); + datasourceReq.setName("艺术类型"); + datasourceReq.setBizName("genre"); + datasourceReq.setDescription("艺术类型"); + datasourceReq.setDatabaseId(1L); + + List dimensions = new ArrayList<>(); + Dim dimension1 = new Dim("", "imp_date", DimensionTypeEnum.time.name(), 0); + dimension1.setTypeParams(new DimensionTimeTypeParams()); + dimensions.add(dimension1); + dimensions.add(new Dim("活跃区域", "most_popular_in", DimensionTypeEnum.categorical.name(), 1)); + datasourceReq.setDimensions(dimensions); + + List identifiers = new ArrayList<>(); + identifiers.add(new Identify("音乐类型名称", IdentifyTypeEnum.primary.name(), 
"g_name")); + datasourceReq.setIdentifiers(identifiers); + + List measures = new ArrayList<>(); + Measure measure = new Measure("评分", "rating", AggOperatorEnum.SUM.name(), 0); + measures.add(measure); + datasourceReq.setMeasures(measures); + + datasourceReq.setQueryType("sql_query"); + datasourceReq.setSqlQuery("SELECT g_name, rating, most_popular_in FROM genre"); + datasourceService.createDatasource(datasourceReq, user); + } + + public void addDatasource_2() throws Exception { + DatasourceReq datasourceReq = new DatasourceReq(); + datasourceReq.setModelId(3L); + datasourceReq.setName("艺术家"); + datasourceReq.setBizName("artist"); + datasourceReq.setDescription("艺术家"); + datasourceReq.setDatabaseId(1L); + + List dimensions = new ArrayList<>(); + dimensions.add(new Dim("国籍", "country", DimensionTypeEnum.categorical.name(), 1)); + dimensions.add(new Dim("性别", "gender", DimensionTypeEnum.categorical.name(), 1)); + datasourceReq.setDimensions(dimensions); + + List identifiers = new ArrayList<>(); + identifiers.add(new Identify("艺术家名称", IdentifyTypeEnum.primary.name(), "artist_name")); + identifiers.add(new Identify("音乐类型名称", IdentifyTypeEnum.foreign.name(), "g_name")); + datasourceReq.setIdentifiers(identifiers); + + datasourceReq.setMeasures(Collections.emptyList()); + + datasourceReq.setQueryType("sql_query"); + datasourceReq.setSqlQuery("SELECT artist_name, country, gender, g_name FROM artist"); + datasourceService.createDatasource(datasourceReq, user); + } + + public void addDatasource_3() throws Exception { + DatasourceReq datasourceReq = new DatasourceReq(); + datasourceReq.setModelId(3L); + datasourceReq.setName("文件"); + datasourceReq.setBizName("files"); + datasourceReq.setDescription("文件"); + datasourceReq.setDatabaseId(1L); + + List dimensions = new ArrayList<>(); + dimensions.add(new Dim("持续时间", "duration", DimensionTypeEnum.categorical.name(), 1)); + dimensions.add(new Dim("文件格式", "formats", DimensionTypeEnum.categorical.name(), 1)); + 
datasourceReq.setDimensions(dimensions); + + List identifiers = new ArrayList<>(); + identifiers.add(new Identify("歌曲ID", IdentifyTypeEnum.primary.name(), "f_id")); + identifiers.add(new Identify("艺术家名称", IdentifyTypeEnum.foreign.name(), "artist_name")); + datasourceReq.setIdentifiers(identifiers); + + datasourceReq.setMeasures(Collections.emptyList()); + + datasourceReq.setQueryType("sql_query"); + datasourceReq.setSqlQuery("SELECT f_id, artist_name, file_size, duration, formats FROM files"); + datasourceService.createDatasource(datasourceReq, user); + } + + public void addDatasource_4() throws Exception { + DatasourceReq datasourceReq = new DatasourceReq(); + datasourceReq.setModelId(3L); + datasourceReq.setName("歌曲"); + datasourceReq.setBizName("song"); + datasourceReq.setDescription("歌曲"); + datasourceReq.setDatabaseId(1L); + + List dimensions = new ArrayList<>(); + Dim dimension1 = new Dim("", "imp_date", DimensionTypeEnum.time.name(), 0); + dimension1.setTypeParams(new DimensionTimeTypeParams()); + dimensions.add(dimension1); + dimensions.add(new Dim("国家", "country", DimensionTypeEnum.categorical.name(), 1)); + dimensions.add(new Dim("语种", "languages", DimensionTypeEnum.categorical.name(), 1)); + dimensions.add(new Dim("发行时间", "releasedate", DimensionTypeEnum.categorical.name(), 1)); + datasourceReq.setDimensions(dimensions); + + List identifiers = new ArrayList<>(); + identifiers.add(new Identify("歌曲名称", IdentifyTypeEnum.primary.name(), "song_name")); + identifiers.add(new Identify("歌曲ID", IdentifyTypeEnum.foreign.name(), "f_id")); + datasourceReq.setIdentifiers(identifiers); + + List measures = new ArrayList<>(); + measures.add(new Measure("分辨率", "resolution", AggOperatorEnum.SUM.name(), 1)); + measures.add(new Measure("评分", "rating", AggOperatorEnum.SUM.name(), 1)); + datasourceReq.setMeasures(measures); + + datasourceReq.setQueryType("sql_query"); + datasourceReq.setSqlQuery("SELECT imp_date, song_name, artist_name, country, f_id, g_name, " + + " rating, 
languages, releasedate, resolution FROM song"); + datasourceService.createDatasource(datasourceReq, user); + } + +} \ No newline at end of file diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/LoadModelDataDemo.java b/launchers/standalone/src/main/java/com/tencent/supersonic/LoadModelDataDemo.java new file mode 100644 index 000000000..9c8fa271d --- /dev/null +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/LoadModelDataDemo.java @@ -0,0 +1,334 @@ +package com.tencent.supersonic; + +import com.google.common.collect.Lists; +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.auth.api.authorization.pojo.AuthGroup; +import com.tencent.supersonic.auth.api.authorization.pojo.AuthRule; +import com.tencent.supersonic.auth.api.authorization.service.AuthService; +import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; +import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum; +import com.tencent.supersonic.common.pojo.enums.SensitiveLevelEnum; +import com.tencent.supersonic.semantic.api.model.enums.DimensionTypeEnum; +import com.tencent.supersonic.semantic.api.model.enums.IdentifyTypeEnum; +import com.tencent.supersonic.semantic.api.model.enums.SemanticTypeEnum; +import com.tencent.supersonic.semantic.api.model.pojo.Dim; +import com.tencent.supersonic.semantic.api.model.pojo.DimensionTimeTypeParams; +import com.tencent.supersonic.semantic.api.model.pojo.Entity; +import com.tencent.supersonic.semantic.api.model.pojo.Identify; +import com.tencent.supersonic.semantic.api.model.pojo.Measure; +import com.tencent.supersonic.semantic.api.model.pojo.MetricTypeParams; +import com.tencent.supersonic.semantic.api.model.request.DatabaseReq; +import com.tencent.supersonic.semantic.api.model.request.DatasourceReq; +import com.tencent.supersonic.semantic.api.model.request.DimensionReq; +import com.tencent.supersonic.semantic.api.model.request.DomainReq; +import 
com.tencent.supersonic.semantic.api.model.request.MetricReq; +import com.tencent.supersonic.semantic.api.model.request.ModelReq; +import com.tencent.supersonic.semantic.model.domain.DatabaseService; +import com.tencent.supersonic.semantic.model.domain.DatasourceService; +import com.tencent.supersonic.semantic.model.domain.DimensionService; +import com.tencent.supersonic.semantic.model.domain.DomainService; +import com.tencent.supersonic.semantic.model.domain.MetricService; +import com.tencent.supersonic.semantic.model.domain.ModelService; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.CommandLineRunner; +import org.springframework.core.annotation.Order; +import org.springframework.stereotype.Component; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +@Component +@Slf4j +@Order(1) +public class LoadModelDataDemo implements CommandLineRunner { + + private User user = User.getFakeUser(); + + @Value("${spring.h2.demo.enabled:false}") + private boolean demoEnable; + + @Autowired + private DatabaseService databaseService; + @Autowired + private DomainService domainService; + @Autowired + private ModelService modelService; + @Autowired + private DatasourceService datasourceService; + @Autowired + private DimensionService dimensionService; + @Autowired + private MetricService metricService; + @Autowired + private AuthService authService; + + @Override + public void run(String... 
args) { + if (!demoEnable) { + return; + } + try { + addDatabase(); + addDomain(); + addModel_1(); + addDatasource_1(); + addDatasource_2(); + addDatasource_3(); + addModel_2(); + addDatasource_4(); + updateDimension(); + updateMetric(); + addAuthGroup_1(); + addAuthGroup_2(); + } catch (Exception e) { + log.error("Failed to add model demo data", e); + } + + } + + public void addDatabase() { + DatabaseReq databaseReq = new DatabaseReq(); + databaseReq.setName("H2数据实例"); + databaseReq.setDescription("样例数据库实例"); + databaseReq.setType("h2"); + databaseReq.setUrl("jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false"); + databaseReq.setUsername("root"); + databaseReq.setPassword("semantic"); + databaseService.createOrUpdateDatabase(databaseReq, user); + } + + public void addDomain() { + DomainReq domainReq = new DomainReq(); + domainReq.setName("超音数"); + domainReq.setBizName("supersonic"); + domainReq.setParentId(0L); + domainReq.setViewers(Arrays.asList("admin", "tom", "jack")); + domainReq.setViewOrgs(Collections.singletonList("admin")); + domainReq.setAdmins(Collections.singletonList("admin")); + domainReq.setAdminOrgs(Collections.emptyList()); + domainService.createDomain(domainReq, user); + } + + public void addModel_1() { + ModelReq modelReq = new ModelReq(); + modelReq.setName("超音数"); + modelReq.setBizName("supersonic"); + modelReq.setDomainId(1L); + modelReq.setViewers(Arrays.asList("admin", "tom", "jack")); + modelReq.setViewOrgs(Collections.singletonList("admin")); + modelReq.setAdmins(Collections.singletonList("admin")); + modelReq.setAdminOrgs(Collections.emptyList()); + modelService.createModel(modelReq, user); + } + + public void addDatasource_1() throws Exception { + DatasourceReq datasourceReq = new DatasourceReq(); + datasourceReq.setName("用户部门"); + datasourceReq.setBizName("user_department"); + datasourceReq.setDescription("用户部门"); + datasourceReq.setDatabaseId(1L); + + List identifiers = new ArrayList<>(); + identifiers.add(new Identify("用户名", 
IdentifyTypeEnum.primary.name(), "user_name")); + datasourceReq.setIdentifiers(identifiers); + + List dimensions = new ArrayList<>(); + dimensions.add(new Dim("部门", "department", + DimensionTypeEnum.categorical.name(), 1)); + datasourceReq.setDimensions(dimensions); + + datasourceReq.setMeasures(Collections.emptyList()); + datasourceReq.setQueryType("table_query"); + datasourceReq.setTableQuery("PUBLIC.s2_user_department"); + datasourceReq.setModelId(1L); + datasourceService.createDatasource(datasourceReq, user); + } + + public void addDatasource_2() throws Exception { + DatasourceReq datasourceReq = new DatasourceReq(); + datasourceReq.setName("PVUV统计"); + datasourceReq.setBizName("s2_pv_uv_statis"); + datasourceReq.setDescription("PVUV统计"); + datasourceReq.setDatabaseId(1L); + + List identifiers = new ArrayList<>(); + identifiers.add(new Identify("用户名", IdentifyTypeEnum.primary.name(), "user_name")); + datasourceReq.setIdentifiers(identifiers); + + List dimensions = new ArrayList<>(); + Dim dimension1 = new Dim("", "imp_date", DimensionTypeEnum.time.name(), 0); + dimension1.setTypeParams(new DimensionTimeTypeParams()); + dimensions.add(dimension1); + Dim dimension2 = new Dim("", "page", DimensionTypeEnum.categorical.name(), 0); + dimensions.add(dimension2); + datasourceReq.setDimensions(dimensions); + + List measures = new ArrayList<>(); + Measure measure1 = new Measure("访问次数", "pv", AggOperatorEnum.SUM.name(), 1); + measures.add(measure1); + + Measure measure2 = new Measure("访问人数", "uv", AggOperatorEnum.COUNT_DISTINCT.name(), 1); + measures.add(measure2); + + datasourceReq.setMeasures(measures); + datasourceReq.setSqlQuery("SELECT imp_date, user_name, page, 1 as pv, user_name as uv FROM s2_pv_uv_statis"); + datasourceReq.setQueryType("sql_query"); + datasourceReq.setModelId(1L); + datasourceService.createDatasource(datasourceReq, user); + } + + public void addDatasource_3() throws Exception { + DatasourceReq datasourceReq = new DatasourceReq(); + 
datasourceReq.setName("停留时长统计"); + datasourceReq.setBizName("s2_stay_time_statis"); + datasourceReq.setDescription("停留时长统计"); + datasourceReq.setDatabaseId(1L); + + List identifiers = new ArrayList<>(); + identifiers.add(new Identify("用户名", IdentifyTypeEnum.primary.name(), "user_name")); + datasourceReq.setIdentifiers(identifiers); + + List dimensions = new ArrayList<>(); + Dim dimension1 = new Dim("", "imp_date", DimensionTypeEnum.time.name(), 0); + dimension1.setTypeParams(new DimensionTimeTypeParams()); + dimensions.add(dimension1); + Dim dimension2 = new Dim("页面", "page", DimensionTypeEnum.categorical.name(), 1); + dimensions.add(dimension2); + datasourceReq.setDimensions(dimensions); + + List measures = new ArrayList<>(); + Measure measure1 = new Measure("停留时长", "stay_hours", AggregateTypeEnum.SUM.name(), 1); + measures.add(measure1); + + datasourceReq.setMeasures(measures); + datasourceReq.setTableQuery("PUBLIC.s2_stay_time_statis"); + datasourceReq.setQueryType("table_query"); + datasourceReq.setModelId(1L); + datasourceService.createDatasource(datasourceReq, user); + } + + public void addModel_2() { + ModelReq modelReq = new ModelReq(); + modelReq.setName("艺人库"); + modelReq.setBizName("singer"); + modelReq.setDomainId(1L); + modelReq.setViewers(Arrays.asList("admin", "tom", "jack")); + modelReq.setViewOrgs(Collections.singletonList("admin")); + modelReq.setAdmins(Collections.singletonList("admin")); + modelReq.setAdminOrgs(Collections.emptyList()); + modelReq.setEntity(new Entity(7L, Arrays.asList("歌手", "艺人"))); + modelService.createModel(modelReq, user); + } + + public void addDatasource_4() throws Exception { + DatasourceReq datasourceReq = new DatasourceReq(); + datasourceReq.setName("艺人库"); + datasourceReq.setBizName("singer"); + datasourceReq.setDescription("艺人库"); + datasourceReq.setDatabaseId(1L); + + List identifiers = new ArrayList<>(); + identifiers.add(new Identify("歌手名", IdentifyTypeEnum.primary.name(), "singer_name")); + 
datasourceReq.setIdentifiers(identifiers); + + List dimensions = new ArrayList<>(); + Dim dimension1 = new Dim("", "imp_date", DimensionTypeEnum.time.name(), 0); + dimension1.setTypeParams(new DimensionTimeTypeParams()); + dimensions.add(dimension1); + dimensions.add(new Dim("活跃区域", "act_area", + DimensionTypeEnum.categorical.name(), 1)); + dimensions.add(new Dim("代表作", "song_name", + DimensionTypeEnum.categorical.name(), 1)); + dimensions.add(new Dim("风格", "genre", + DimensionTypeEnum.categorical.name(), 1)); + datasourceReq.setDimensions(dimensions); + + Measure measure1 = new Measure("播放量", "js_play_cnt", "sum", 1); + Measure measure2 = new Measure("下载量", "down_cnt", "sum", 1); + Measure measure3 = new Measure("收藏量", "favor_cnt", "sum", 1); + datasourceReq.setMeasures(Lists.newArrayList(measure1, measure2, measure3)); + datasourceReq.setQueryType("table_query"); + datasourceReq.setTableQuery("PUBLIC.singer"); + datasourceReq.setModelId(2L); + datasourceService.createDatasource(datasourceReq, user); + } + + public void updateDimension() throws Exception { + DimensionReq dimensionReq = new DimensionReq(); + dimensionReq.setModelId(1L); + dimensionReq.setType(DimensionTypeEnum.categorical.name()); + dimensionReq.setId(3L); + dimensionReq.setName("页面"); + dimensionReq.setBizName("page"); + dimensionReq.setDatasourceId(3L); + dimensionReq.setAlias("page"); + dimensionReq.setSemanticType(SemanticTypeEnum.CATEGORY.name()); + dimensionReq.setSensitiveLevel(2); + dimensionReq.setDescription("页面"); + dimensionReq.setExpr("page"); + dimensionReq.setDimValueMaps(Collections.emptyList()); + dimensionService.updateDimension(dimensionReq, user); + } + + public void updateMetric() throws Exception { + MetricReq metricReq = new MetricReq(); + metricReq.setModelId(1L); + metricReq.setId(3L); + metricReq.setName("停留时长"); + metricReq.setBizName("stay_hours"); + metricReq.setSensitiveLevel(SensitiveLevelEnum.HIGH.getCode()); + metricReq.setDescription("停留时长"); + 
metricReq.setTags(Collections.singletonList("核心指标")); + metricReq.setAlias("访问时长"); + MetricTypeParams metricTypeParams = new MetricTypeParams(); + metricTypeParams.setExpr("s2_stay_time_statis_stay_hours"); + List measures = new ArrayList<>(); + Measure measure = new Measure("停留时长", + "s2_stay_time_statis_stay_hours", AggOperatorEnum.SUM.getOperator(), 1); + measure.setDatasourceId(3L); + measures.add(measure); + metricTypeParams.setMeasures(measures); + metricReq.setTypeParams(metricTypeParams); + metricService.updateExprMetric(metricReq, user); + } + + public void addAuthGroup_1() { + AuthGroup authGroupReq = new AuthGroup(); + authGroupReq.setModelId("1"); + authGroupReq.setName("admin-permission"); + + List authRules = new ArrayList<>(); + AuthRule authRule = new AuthRule(); + authRule.setMetrics(Collections.singletonList("stay_hours")); + authRule.setDimensions(Collections.singletonList("page")); + authRules.add(authRule); + + authGroupReq.setAuthRules(authRules); + authGroupReq.setAuthorizedUsers(Collections.singletonList("jack")); + authGroupReq.setAuthorizedDepartmentIds(Collections.emptyList()); + authService.addOrUpdateAuthGroup(authGroupReq); + } + + public void addAuthGroup_2() { + AuthGroup authGroupReq = new AuthGroup(); + authGroupReq.setModelId("1"); + authGroupReq.setName("tom_sales_permission"); + + List authRules = new ArrayList<>(); + AuthRule authRule = new AuthRule(); + authRule.setMetrics(Collections.singletonList("stay_hours")); + authRule.setDimensions(Collections.singletonList("page")); + authRules.add(authRule); + + authGroupReq.setAuthRules(authRules); + authGroupReq.setDimensionFilters(Collections.singletonList("department in ('sales')")); + authGroupReq.setDimensionFilterDescription("部门 in [sales]"); + authGroupReq.setAuthorizedUsers(Collections.singletonList("tom")); + authGroupReq.setAuthorizedDepartmentIds(Collections.emptyList()); + authService.addOrUpdateAuthGroup(authGroupReq); + } + +} \ No newline at end of file diff --git 
a/launchers/standalone/src/main/resources/META-INF/spring.factories b/launchers/standalone/src/main/resources/META-INF/spring.factories index c0adbf1d6..3e8c7a5d6 100644 --- a/launchers/standalone/src/main/resources/META-INF/spring.factories +++ b/launchers/standalone/src/main/resources/META-INF/spring.factories @@ -31,12 +31,9 @@ com.tencent.supersonic.auth.api.authentication.adaptor.UserAdaptor=\ com.tencent.supersonic.auth.authentication.adaptor.DefaultUserAdaptor com.tencent.supersonic.chat.api.component.SemanticCorrector=\ - com.tencent.supersonic.chat.corrector.DateFieldCorrector, \ - com.tencent.supersonic.chat.corrector.FunctionAliasCorrector, \ - com.tencent.supersonic.chat.corrector.FieldNameCorrector, \ - com.tencent.supersonic.chat.corrector.FieldCorrector, \ - com.tencent.supersonic.chat.corrector.FunctionCorrector, \ - com.tencent.supersonic.chat.corrector.TableNameCorrector, \ - com.tencent.supersonic.chat.corrector.QueryFilterAppend, \ - com.tencent.supersonic.chat.corrector.SelectFieldAppendCorrector, \ - com.tencent.supersonic.chat.corrector.FieldValueCorrector + com.tencent.supersonic.chat.corrector.GlobalCorrector, \ + com.tencent.supersonic.chat.corrector.TableCorrector, \ + com.tencent.supersonic.chat.corrector.GroupByCorrector, \ + com.tencent.supersonic.chat.corrector.SelectCorrector, \ + com.tencent.supersonic.chat.corrector.WhereCorrector, \ + com.tencent.supersonic.chat.corrector.HavingCorrector diff --git a/launchers/standalone/src/main/resources/data/dictionary/custom/benchmark_cspider.txt b/launchers/standalone/src/main/resources/data/dictionary/custom/benchmark_cspider.txt new file mode 100644 index 000000000..2c86bfce4 --- /dev/null +++ b/launchers/standalone/src/main/resources/data/dictionary/custom/benchmark_cspider.txt @@ -0,0 +1,31 @@ +孟加拉国 _3_8 9000 +锡尔赫特、吉大港、库斯蒂亚 _3_8 9000 +加拿大 _3_8 9000 +美国 _3_8 9000 +tagore _3_9 9000 +nazrul _3_9 9000 +民间 _3_9 9000 +现代 _3_9 9000 +蓝调 _3_9 9000 +流行 _3_9 9000 +孟加拉国 _3_10 9000 +印度 _3_10 9000 +美国 
_3_10 9000 +英国 _3_10 9000 +男性 _3_11 9000 +女性 _3_11 9000 +Shrikanta _3_12 9000 +Prity _3_12 9000 +Farida _3_12 9000 +Topu _3_12 9000 +Enrique _3_12 9000 +Michel _3_12 9000 +mp4 _3_14 9000 +mp3 _3_14 9000 +孟加拉语 _3_16 9000 +英文 _3_16 9000 +Tumi#长袍#尼罗布 _3_18 9000 +舒克诺#帕塔尔#努普尔#帕埃 _3_18 9000 +阿米·奥帕尔·霍伊 _3_18 9000 +我的爱 _3_18 9000 +打败它 _3_18 9000 diff --git a/launchers/standalone/src/main/resources/db/data-h2.sql b/launchers/standalone/src/main/resources/db/data-h2.sql index 68709c0c2..eb8387117 100644 --- a/launchers/standalone/src/main/resources/db/data-h2.sql +++ b/launchers/standalone/src/main/resources/db/data-h2.sql @@ -5,32 +5,6 @@ insert into s2_user (id, `name`, password, display_name, email) values (3, 'tom' insert into s2_user (id, `name`, password, display_name, email, is_admin) values (4, 'lucy','123456','lucy','lucy@xx.com', 1); insert into s2_user (id, `name`, password, display_name, email) values (5, 'alice','123456','alice','alice@xx.com'); --- sample models -insert into s2_domain (id, `name`, biz_name, parent_id, status, created_at, created_by, updated_at, updated_by, `admin`, admin_org, viewer, view_org) VALUES(1, '超音数', 'supersonic', 0, 1, '2023-05-24 00:00:00', 'admin', '2023-05-24 00:00:00', 'admin', 'admin', '', 'admin,tom,jack', 'admin' ); -insert into s2_model (id, `name`, biz_name, domain_id, created_at, created_by, updated_at, updated_by, `admin`, admin_org, is_open, viewer, view_org, entity) VALUES(1, '超音数', 'supersonic', 1, '2023-05-24 00:00:00', 'admin', '2023-05-24 00:00:00', 'admin', 'admin', '', 0, 'admin,tom,jack', 'admin','' ); -insert into s2_model (id, `name`, biz_name, domain_id, created_at, created_by, updated_at, updated_by, `admin`, admin_org, is_open, viewer, view_org, entity) VALUES(2, '艺人库', 'singer', 1, '2023-05-24 00:00:00', 'admin', '2023-05-24 00:00:00', 'admin', 'admin', '', 0, 'admin,tom,jack', 'admin','{"entityId": 7, "names": ["歌手", "艺人"]}' ); -insert into s2_database (id, `name`, description, `type` ,config ,created_at 
,created_by ,updated_at ,updated_by, `admin`) VALUES(1, 'H2数据实例', '', 'h2', '{"password":"semantic","url":"jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false","userName":"root"}', '2023-05-24 00:00:00', 'admin', '2023-05-24 00:00:00', 'admin', 'admin'); -insert into s2_datasource (id , model_id, `name`, biz_name, description, database_id ,datasource_detail, created_at, created_by, updated_at, updated_by ) VALUES(1, 1, '停留时长统计', 's2_stay_time_statis', '停留时长统计', 1, '{"dimensions":[{"bizName":"imp_date","dateFormat":"yyyy-MM-dd","expr":"imp_date","isCreateDimension":0,"type":"time","typeParams":{"isPrimary":"true","timeGranularity":"day"}},{"bizName":"page","dateFormat":"yyyy-MM-dd","expr":"page","isCreateDimension":0,"type":"categorical"}],"identifiers":[{"bizName":"user_name","name":"用户名","type":"primary"}],"measures":[{"agg":"sum","bizName":"s2_stay_time_statis_stay_hours","expr":"stay_hours","isCreateMetric":1,"name":"停留时长"}],"queryType":"sql_query","sqlQuery":"SELECT imp_date, page,user_name,stay_hours FROM s2_stay_time_statis"}', '2023-05-25 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin'); -insert into s2_datasource (id , model_id, `name`, biz_name, description, database_id ,datasource_detail, created_at, created_by, updated_at, updated_by ) VALUES(2, 1, 'PVUV统计', 's2_pv_uv_statis', 'PVUV统计', 1, '{"dimensions":[{"bizName":"imp_date","dateFormat":"yyyy-MM-dd","expr":"imp_date","isCreateDimension":0,"type":"time","typeParams":{"isPrimary":"true","timeGranularity":"day"}},{"bizName":"page","dateFormat":"yyyy-MM-dd","expr":"page","isCreateDimension":0,"type":"categorical"}],"identifiers":[{"bizName":"user_name","name":"用户名","type":"primary"}],"measures":[{"agg":"sum","bizName":"s2_pv_uv_statis_pv","expr":"pv","isCreateMetric":1,"name":"访问次数"},{"agg":"count_distinct","bizName":"s2_pv_uv_statis_uv","expr":"uv","isCreateMetric":1,"name":"访问人数"}],"queryType":"sql_query","sqlQuery":"SELECT imp_date, user_name,page,1 as pv, user_name as uv FROM s2_pv_uv_statis"}', 
'2023-05-25 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin'); -insert into s2_datasource (id , model_id, `name`, biz_name, description, database_id ,datasource_detail, created_at, created_by, updated_at, updated_by ) VALUES(3, 1, '用户部门', 'user_department', '用户部门', 1, '{"dimensions":[{"bizName":"department","dateFormat":"yyyy-MM-dd","expr":"department","isCreateDimension":1,"name":"部门","type":"categorical"}],"identifiers":[{"bizName":"user_name","name":"用户名","type":"primary"}],"measures":[],"queryType":"sql_query","sqlQuery":"SELECT user_name,department FROM s2_user_department"}', '2023-05-25 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin'); -insert into s2_datasource (id , model_id, `name`, biz_name, description, database_id ,datasource_detail, created_at, created_by, updated_at, updated_by ) VALUES(4, 2, '艺人库', 'singer', '艺人库', 1, '{"dimensions":[{"bizName":"imp_date","dateFormat":"yyyy-MM-dd","expr":"imp_date","isCreateDimension":0,"type":"time","typeParams":{"isPrimary":"true","timeGranularity":"day"}},{"bizName":"act_area","dateFormat":"yyyy-MM-dd","expr":"act_area","isCreateDimension":1,"name":"活跃区域","type":"categorical"},{"bizName":"song_name","dateFormat":"yyyy-MM-dd","expr":"song_name","isCreateDimension":1,"name":"代表作","type":"categorical"},{"bizName":"genre","dateFormat":"yyyy-MM-dd","expr":"genre","isCreateDimension":1,"name":"风格","type":"categorical"}],"identifiers":[{"bizName":"singer_name","name":"歌手名","type":"primary"}],"measures":[{"agg":"sum","bizName":"music_down_cnt","expr":"down_cnt","isCreateMetric":1,"name":"下载量"},{"agg":"sum","bizName":"music_js_play_cnt","expr":"js_play_cnt","isCreateMetric":1,"name":"播放量"},{"agg":"sum","bizName":"music_favor_cnt","expr":"favor_cnt","isCreateMetric":1,"name":"收藏量"}],"queryType":"sql_query","sqlQuery":"SELECT imp_date,singer_name,act_area,song_name,genre,js_play_cnt,down_cnt,favor_cnt FROM singer "}', '2023-05-25 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin'); -insert into s2_datasource_rela (id , 
model_id, `datasource_from`, datasource_to, join_key, created_at, created_by, updated_at, updated_by ) VALUES(1, 1, 1, 2, 'user_name', '2023-05-25 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin'); -insert into s2_datasource_rela (id , model_id, `datasource_from`, datasource_to, join_key, created_at, created_by, updated_at, updated_by ) VALUES(2, 1, 1, 3, 'user_name', '2023-05-25 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin'); -insert into s2_datasource_rela (id , model_id, `datasource_from`, datasource_to, join_key, created_at, created_by, updated_at, updated_by ) VALUES(3, 1, 2, 3, 'user_name', '2023-05-25 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin'); -insert into s2_dimension (id , model_id, datasource_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, expr, created_at, created_by, updated_at, updated_by, semantic_type, dim_value_maps) VALUES(1, 1, 3, '部门', 'department', '部门', 1, 0, 'categorical', NULL, 'department', '2023-05-24 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin', 'CATEGORY', '[{"alias":["人力资源","人力"],"bizName":"人力资源","techName":"HR"},{"alias":["营销","销售"],"bizName":"营销部门","techName":"sales"}]'); -insert into s2_dimension (id , model_id, datasource_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, expr, created_at, created_by, updated_at, updated_by, semantic_type) VALUES(2, 1, 1, '用户名', 'user_name', '用户名', 1, 0, 'primary', NULL, 'user_name', '2023-05-24 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin', 'CATEGORY'); -insert into s2_dimension (id , model_id, datasource_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, expr, created_at, created_by, updated_at, updated_by, semantic_type) VALUES(3, 1, 2, '页面', 'page', '页面', 1, 2, 'categorical', NULL, 'page', '2023-05-24 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin', 'CATEGORY'); -insert into s2_dimension (id , model_id, datasource_id, `name`, biz_name, description, status, sensitive_level, 
`type`, type_params, expr, created_at, created_by, updated_at, updated_by, semantic_type) VALUES(4, 2, 4, '活跃区域', 'act_area', '活跃区域', 1, 2, 'categorical', NULL, 'act_area', '2023-05-24 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin', 'CATEGORY'); -insert into s2_dimension (id , model_id, datasource_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, expr, created_at, created_by, updated_at, updated_by, semantic_type) VALUES(5, 2, 4, '代表作', 'song_name', '代表作', 1, 2, 'categorical', NULL, 'song_name', '2023-05-24 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin', 'CATEGORY'); -insert into s2_dimension (id , model_id, datasource_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, expr, created_at, created_by, updated_at, updated_by, semantic_type) VALUES(6, 2, 4, '风格', 'genre', '风格', 1, 2, 'categorical', NULL, 'genre', '2023-05-24 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin', 'CATEGORY'); -insert into s2_dimension (id , model_id, datasource_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, expr, created_at, created_by, updated_at, updated_by, semantic_type) VALUES(7, 2, 4, '歌手名', 'singer_name', '歌手名', 1, 2, 'categorical', NULL, 'singer_name', '2023-05-24 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin', 'CATEGORY'); -insert into s2_metric (id, model_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, created_at, created_by, updated_at, updated_by, data_format_type, data_format) VALUES(1, 1, '停留时长', 'stay_hours', '停留时长', 1, 2, 'ATOMIC', '{"expr":"s2_stay_time_statis_stay_hours","measures":[{"agg":"sum","expr":"stay_hours","isCreateMetric":1,"datasourceId":1,"bizName":"s2_stay_time_statis_stay_hours","name":"s2_stay_time_statis_stay_hours"}]}' , '2023-05-24 17:00:00', 'admin', '2023-05-25 00:00:00', 'admin', NULL, NULL ); -insert into s2_metric (id, model_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, 
created_at, created_by, updated_at, updated_by, data_format_type, data_format) VALUES(2, 1, '访问次数', 'pv', '访问次数', 1, 0, 'ATOMIC', ' {"expr":"s2_pv_uv_statis_pv","measures":[{"agg":"sum","bizName":"s2_pv_uv_statis_pv","datasourceId":2,"expr":"pv","isCreateMetric":1,"name":"s2_pv_uv_statis_pv"}]}' , '2023-05-24 17:00:00', 'admin', '2023-05-25 00:00:00', 'admin', NULL, NULL ); -insert into s2_metric (id, model_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, created_at, created_by, updated_at, updated_by, data_format_type, data_format) VALUES(3, 1, '访问人数', 'uv', '访问人数', 1, 0, 'ATOMIC', ' {"expr":"s2_pv_uv_statis_uv","measures":[{"agg":"count_distinct","bizName":"s2_pv_uv_statis_uv","datasourceId":2,"expr":"uv","isCreateMetric":1,"name":"s2_pv_uv_statis_uv"}]}' , '2023-05-24 17:00:00', 'admin', '2023-05-25 00:00:00', 'admin', NULL, NULL ); -insert into s2_metric (id, model_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, created_at, created_by, updated_at, updated_by, data_format_type, data_format) VALUES(4, 2, '播放量', 'js_play_cnt', '播放量', 1, 2, 'ATOMIC', '{"expr":"music_js_play_cnt","measures":[{"agg":"sum","expr":"js_play_cnt","isCreateMetric":1,"datasourceId":4,"bizName":"music_js_play_cnt","name":"music_js_play_cnt"}]}' , '2023-05-24 17:00:00', 'admin', '2023-05-25 00:00:00', 'admin', NULL, NULL ); -insert into s2_metric (id, model_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, created_at, created_by, updated_at, updated_by, data_format_type, data_format) VALUES(5, 2, '下载量', 'down_cnt', '下载量', 1, 0, 'ATOMIC', ' {"expr":"music_down_cnt","measures":[{"agg":"sum","bizName":"music_down_cnt","datasourceId":4,"expr":"down_cnt","isCreateMetric":1,"name":"music_down_cnt"}]}' , '2023-05-24 17:00:00', 'admin', '2023-05-25 00:00:00', 'admin', NULL, NULL ); -insert into s2_metric (id, model_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, 
created_at, created_by, updated_at, updated_by, data_format_type, data_format) VALUES(6, 2, '收藏量', 'favor_cnt', '收藏量', 1, 0, 'ATOMIC', ' {"expr":"music_favor_cnt","measures":[{"agg":"sum","bizName":"music_favor_cnt","datasourceId":4,"expr":"favor_cnt","isCreateMetric":1,"name":"music_favor_cnt"}]}' , '2023-05-24 17:00:00', 'admin', '2023-05-25 00:00:00', 'admin', NULL, NULL ); - insert into s2_available_date_info(`item_id` ,`type` ,`date_format` ,`start_date` ,`end_date` ,`unavailable_date` ,`created_at` ,`created_by` ,`updated_at` ,`updated_by` ) values (1, 'dimension', 'yyyy-MM-dd', DATEADD('DAY', -28, CURRENT_DATE()), DATEADD('DAY', -1, CURRENT_DATE()), '[]', '2023-06-01', 'admin', '2023-06-01', 'admin'); insert into s2_available_date_info(`item_id` ,`type` ,`date_format` ,`start_date` ,`end_date` ,`unavailable_date` ,`created_at` ,`created_by` ,`updated_at` ,`updated_by` ) @@ -38,11 +12,6 @@ values (2, 'dimension', 'yyyy-MM-dd', DATEADD('DAY', -28, CURRENT_DATE()), DATEA insert into s2_available_date_info(`item_id` ,`type` ,`date_format` ,`start_date` ,`end_date` ,`unavailable_date` ,`created_at` ,`created_by` ,`updated_at` ,`updated_by` ) values (3, 'dimension', 'yyyy-MM-dd', DATEADD('DAY', -28, CURRENT_DATE()), DATEADD('DAY', -1, CURRENT_DATE()), '[]', '2023-06-01', 'admin', '2023-06-01', 'admin'); -insert into s2_auth_groups (group_id, config) -values (1, '{"modelId":"1","name":"admin-permission","groupId":1,"authRules":[{"metrics":["stay_hours"],"dimensions":["page"]}],"dimensionFilters":[""],"dimensionFilterDescription":"授权admin 页面和停留时长权限","authorizedUsers":["admin"],"authorizedDepartmentIds":[]}'); -insert into s2_auth_groups (group_id, config) -values (2, '{"modelId":"1","name":"tom_sales_permission","groupId":2,"authRules":[{"metrics":["stay_hours"],"dimensions":["page"]}],"dimensionFilters":["department in (''sales'')"],"dimensionFilterDescription":"部门 in [sales]", "authorizedUsers":["tom"],"authorizedDepartmentIds":[]}'); - -- sample data INSERT INTO 
singer (imp_date,singer_name,act_area, song_name,genre,js_play_cnt,down_cnt,favor_cnt) VALUES (DATEADD('DAY', -1, CURRENT_DATE()), '周杰伦', '港台','青花瓷','国风',1000000,1000000,1000000); INSERT INTO singer (imp_date,singer_name,act_area, song_name,genre,js_play_cnt,down_cnt,favor_cnt) VALUES (DATEADD('DAY', -5, CURRENT_DATE()), '周杰伦', '港台','青花瓷','国风',1000000,1000000,1000000); @@ -1108,3 +1077,35 @@ INSERT INTO s2_stay_time_statis (imp_date, user_name, stay_hours, page) VALUES ( INSERT INTO s2_stay_time_statis (imp_date, user_name, stay_hours, page) VALUES (DATEADD('DAY', -19, CURRENT_DATE()), 'alice', '0.8131712486302015', 'p2'); INSERT INTO s2_stay_time_statis (imp_date, user_name, stay_hours, page) VALUES (DATEADD('DAY', -15, CURRENT_DATE()), 'lucy', '0.8124302447925607', 'p4'); INSERT INTO s2_stay_time_statis (imp_date, user_name, stay_hours, page) VALUES (DATEADD('DAY', -8, CURRENT_DATE()), 'lucy', '0.039935860913407284', 'p2'); + + + +insert into genre(g_name,rating,most_popular_in) VALUES ('tagore',8,'孟加拉国'); +insert into genre(g_name,rating,most_popular_in) VALUES ('nazrul',7,'孟加拉国'); +insert into genre(g_name,rating,most_popular_in) VALUES ('民间',9,'锡尔赫特、吉大港、库斯蒂亚'); +insert into genre(g_name,rating,most_popular_in) VALUES ('现代',8,'孟加拉国'); +insert into genre(g_name,rating,most_popular_in) VALUES ('蓝调',7,'加拿大'); +insert into genre(g_name,rating,most_popular_in) VALUES ('流行',9,'美国'); + +insert into artist(artist_name,country,gender,g_name) VALUES ('Shrikanta','印度','男性','tagore'); +insert into artist(artist_name,country,gender,g_name) VALUES ('Prity','孟加拉国','女性','nazrul'); +insert into artist(artist_name,country,gender,g_name) VALUES ('Farida','孟加拉国','女性','民间'); +insert into artist(artist_name,country,gender,g_name) VALUES ('Topu','印度','女性','现代'); +insert into artist(artist_name,country,gender,g_name) VALUES ('Enrique','美国','男性','蓝调'); +insert into artist(artist_name,country,gender,g_name) VALUES ('Michel','英国','男性','流行'); + +insert into 
files(f_id,artist_name,file_size,duration,formats) VALUES (1,'Shrikanta','3.78 MB','3:45','mp4'); +insert into files(f_id,artist_name,file_size,duration,formats) VALUES (2,'Prity','4.12 MB','2:56','mp3'); +insert into files(f_id,artist_name,file_size,duration,formats) VALUES (3,'Farida','3.69 MB','4:12','mp4'); +insert into files(f_id,artist_name,file_size,duration,formats) VALUES (4,'Enrique','4.58 MB','5:23','mp4'); +insert into files(f_id,artist_name,file_size,duration,formats) VALUES (5,'Michel','5.10 MB','4:34','mp3'); +insert into files(f_id,artist_name,file_size,duration,formats) VALUES (6,'Topu','4.10 MB','4:30','mp4'); + +insert into song(imp_date,song_name,artist_name,country,f_id,g_name,rating,languages,releasedate,resolution) VALUES (DATEADD('DAY', 0, CURRENT_DATE()),'Tumi 长袍 尼罗布','Shrikanta','印度',1,'tagore',8,'孟加拉语','28-AUG-2011',1080); +insert into song(imp_date,song_name,artist_name,country,f_id,g_name,rating,languages,releasedate,resolution) VALUES (DATEADD('DAY', 0, CURRENT_DATE()),'舒克诺 帕塔尔 努普尔 帕埃','Prity','孟加拉国',2,'nazrul',5,'孟加拉语','21-SEP-1997',512); +insert into song(imp_date,song_name,artist_name,country,f_id,g_name,rating,languages,releasedate,resolution) VALUES (DATEADD('DAY', 0, CURRENT_DATE()),'阿米·奥帕尔·霍伊','Farida','孟加拉国',3,'民间',7,'孟加拉语','7-APR-2001',320); +insert into song(imp_date,song_name,artist_name,country,f_id,g_name,rating,languages,releasedate,resolution) VALUES (DATEADD('DAY', 0, CURRENT_DATE()),'我的爱','Enrique','美国',4,'蓝调',6,'英文','24-JAN-2007',1080); +insert into song(imp_date,song_name,artist_name,country,f_id,g_name,rating,languages,releasedate,resolution) VALUES (DATEADD('DAY', 0, CURRENT_DATE()),'打败它','Michel','英国',5,'流行',8,'英文','17-MAR-2002',720); +insert into song(imp_date,song_name,artist_name,country,f_id,g_name,rating,languages,releasedate,resolution) VALUES (DATEADD('DAY', 0, CURRENT_DATE()),'阿杰伊阿卡什','Topu','印度',6,'现代',10,'孟加拉语','27-MAR-2004',320); + +-- benchmark diff --git 
a/launchers/standalone/src/main/resources/db/schema-h2.sql b/launchers/standalone/src/main/resources/db/schema-h2.sql index 1370bc82d..087686e7e 100644 --- a/launchers/standalone/src/main/resources/db/schema-h2.sql +++ b/launchers/standalone/src/main/resources/db/schema-h2.sql @@ -414,4 +414,47 @@ COMMENT ON TABLE s2_dictionary_task IS 'dictionary task information table'; +-- benchmark +CREATE TABLE IF NOT EXISTS `genre` ( + `g_name` varchar(20) NOT NULL , -- genre name + `rating` INT , + `most_popular_in` varchar(50) , + PRIMARY KEY (`g_name`) + ); +COMMENT ON TABLE genre IS 'genre'; + +CREATE TABLE IF NOT EXISTS `artist` ( + `artist_name` varchar(50) NOT NULL , -- genre name + `country` varchar(20) , + `gender` varchar(20) , + `g_name` varchar(50) + ); +COMMENT ON TABLE artist IS 'artist'; + +CREATE TABLE IF NOT EXISTS `files` ( + `f_id` bigINT NOT NULL, + `artist_name` varchar(50) , + `file_size` varchar(20) , + `duration` varchar(20) , + `formats` varchar(20) , + PRIMARY KEY (`f_id`) + ); +COMMENT ON TABLE files IS 'files'; + +CREATE TABLE IF NOT EXISTS `song` ( + `imp_date` varchar(50) , + `song_name` varchar(50) , + `artist_name` varchar(50) , + `country` varchar(20) , + `f_id` bigINT , + `g_name` varchar(20) , + `rating` INT , + `languages` varchar(20) , + `releasedate` varchar(50) , + `resolution` bigINT NOT NULL + ); +COMMENT ON TABLE song IS 'song'; + +-- benchmark + diff --git a/launchers/standalone/src/main/resources/hanlp.properties b/launchers/standalone/src/main/resources/hanlp.properties index 9d91904eb..8faa512a4 100644 --- a/launchers/standalone/src/main/resources/hanlp.properties +++ b/launchers/standalone/src/main/resources/hanlp.properties @@ -1,2 +1,2 @@ root=. 
-CustomDictionaryPath=data/dictionary/custom/DimValue_1_1.txt;data/dictionary/custom/DimValue_1_2.txt;data/dictionary/custom/DimValue_1_3.txt; \ No newline at end of file +CustomDictionaryPath=data/dictionary/custom/DimValue_1_1.txt;data/dictionary/custom/DimValue_1_2.txt;data/dictionary/custom/DimValue_1_3.txt;data/dictionary/custom/benchmark_cspider.txt; diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/benchmark/CSpider.java b/launchers/standalone/src/test/java/com/tencent/supersonic/benchmark/CSpider.java new file mode 100644 index 000000000..f85f0588b --- /dev/null +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/benchmark/CSpider.java @@ -0,0 +1,10 @@ +package com.tencent.supersonic.benchmark; + +import org.junit.Test; + +public class CSpider { + @Test + public void case1(){ + + } +} diff --git a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricQueryTest.java b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricQueryTest.java index 4a980967d..d44c5cd55 100644 --- a/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricQueryTest.java +++ b/launchers/standalone/src/test/java/com/tencent/supersonic/integration/MetricQueryTest.java @@ -210,7 +210,7 @@ public class MetricQueryTest extends BaseQueryTest { ChatConfigEditReqReq extendEditCmd = new ChatConfigEditReqReq(); BeanUtils.copyProperties(chatConfig, extendEditCmd); // add blacklist - List blackMetrics = Arrays.asList(3L); + List blackMetrics = Arrays.asList(2L); extendEditCmd.getChatAggConfig().getVisibility().setBlackMetricIdList(blackMetrics); configService.editConfig(extendEditCmd, User.getFakeUser()); diff --git a/launchers/standalone/src/test/resources/db/data-h2.sql b/launchers/standalone/src/test/resources/db/data-h2.sql index 10f6a3ef5..d161b9c14 100644 --- a/launchers/standalone/src/test/resources/db/data-h2.sql +++ b/launchers/standalone/src/test/resources/db/data-h2.sql @@ -4,32 +4,6 @@ 
insert into s2_user (id, `name`, password, display_name, email) values (2, 'jack insert into s2_user (id, `name`, password, display_name, email) values (3, 'tom','123456','tom','tom@xx.com'); insert into s2_user (id, `name`, password, display_name, email) values (4, 'lucy','123456','lucy','lucy@xx.com'); --- sample models -insert into s2_domain (id, `name`, biz_name, parent_id, status, created_at, created_by, updated_at, updated_by, `admin`, admin_org, viewer, view_org) VALUES(1, '超音数', 'supersonic', 0, 1, '2023-05-24 00:00:00', 'admin', '2023-05-24 00:00:00', 'admin', 'admin', '', 'admin,tom,jack', 'admin' ); -insert into s2_model (id, `name`, biz_name, domain_id, created_at, created_by, updated_at, updated_by, `admin`, admin_org, is_open, viewer, view_org, entity) VALUES(1, '超音数', 'supersonic', 1, '2023-05-24 00:00:00', 'admin', '2023-05-24 00:00:00', 'admin', 'admin', '', 0, 'admin,tom,jack', 'admin','' ); -insert into s2_model (id, `name`, biz_name, domain_id, created_at, created_by, updated_at, updated_by, `admin`, admin_org, is_open, viewer, view_org, entity) VALUES(2, '艺人库', 'singer', 1, '2023-05-24 00:00:00', 'admin', '2023-05-24 00:00:00', 'admin', 'admin', '', 0, 'admin,tom,jack', 'admin','{"entityId": 7, "names": ["歌手", "艺人"]}' ); -insert into s2_database (id, `name`, description, `type` ,config ,created_at ,created_by ,updated_at ,updated_by, `admin`) VALUES(1, 'H2数据实例', '', 'h2', '{"password":"semantic","url":"jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false","userName":"root"}', '2023-05-24 00:00:00', 'admin', '2023-05-24 00:00:00', 'admin', 'admin'); -insert into s2_datasource (id , model_id, `name`, biz_name, description, database_id ,datasource_detail, created_at, created_by, updated_at, updated_by ) VALUES(1, 1, '停留时长统计', 's2_stay_time_statis', '停留时长统计', 1, 
'{"dimensions":[{"bizName":"imp_date","dateFormat":"yyyy-MM-dd","expr":"imp_date","isCreateDimension":0,"type":"time","typeParams":{"isPrimary":"true","timeGranularity":"day"}},{"bizName":"page","dateFormat":"yyyy-MM-dd","expr":"page","isCreateDimension":0,"type":"categorical"}],"identifiers":[{"bizName":"user_name","name":"用户名","type":"primary"}],"measures":[{"agg":"sum","bizName":"s2_stay_time_statis_stay_hours","expr":"stay_hours","isCreateMetric":1,"name":"停留时长"}],"queryType":"sql_query","sqlQuery":"SELECT imp_date, page,user_name,stay_hours FROM s2_stay_time_statis"}', '2023-05-25 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin'); -insert into s2_datasource (id , model_id, `name`, biz_name, description, database_id ,datasource_detail, created_at, created_by, updated_at, updated_by ) VALUES(2, 1, 'PVUV统计', 's2_pv_uv_statis', 'PVUV统计', 1, '{"dimensions":[{"bizName":"imp_date","dateFormat":"yyyy-MM-dd","expr":"imp_date","isCreateDimension":0,"type":"time","typeParams":{"isPrimary":"true","timeGranularity":"day"}},{"bizName":"page","dateFormat":"yyyy-MM-dd","expr":"page","isCreateDimension":0,"type":"categorical"}],"identifiers":[{"bizName":"user_name","name":"用户名","type":"primary"}],"measures":[{"agg":"sum","bizName":"s2_pv_uv_statis_pv","expr":"pv","isCreateMetric":1,"name":"访问次数"},{"agg":"count_distinct","bizName":"s2_pv_uv_statis_uv","expr":"uv","isCreateMetric":1,"name":"访问人数"}],"queryType":"sql_query","sqlQuery":"SELECT imp_date, user_name,page,1 as pv, user_name as uv FROM s2_pv_uv_statis"}', '2023-05-25 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin'); -insert into s2_datasource (id , model_id, `name`, biz_name, description, database_id ,datasource_detail, created_at, created_by, updated_at, updated_by ) VALUES(3, 1, '用户部门', 'user_department', '用户部门', 1, 
'{"dimensions":[{"bizName":"department","dateFormat":"yyyy-MM-dd","expr":"department","isCreateDimension":1,"name":"部门","type":"categorical"}],"identifiers":[{"bizName":"user_name","name":"用户名","type":"primary"}],"measures":[],"queryType":"sql_query","sqlQuery":"SELECT user_name,department FROM s2_user_department"}', '2023-05-25 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin'); -insert into s2_datasource (id , model_id, `name`, biz_name, description, database_id ,datasource_detail, created_at, created_by, updated_at, updated_by ) VALUES(4, 2, '艺人库', 'singer', '艺人库', 1, '{"dimensions":[{"bizName":"imp_date","dateFormat":"yyyy-MM-dd","expr":"imp_date","isCreateDimension":0,"type":"time","typeParams":{"isPrimary":"true","timeGranularity":"day"}},{"bizName":"act_area","dateFormat":"yyyy-MM-dd","expr":"act_area","isCreateDimension":1,"name":"活跃区域","type":"categorical"},{"bizName":"song_name","dateFormat":"yyyy-MM-dd","expr":"song_name","isCreateDimension":1,"name":"代表作","type":"categorical"},{"bizName":"genre","dateFormat":"yyyy-MM-dd","expr":"genre","isCreateDimension":1,"name":"风格","type":"categorical"}],"identifiers":[{"bizName":"singer_name","name":"歌手名","type":"primary"}],"measures":[{"agg":"sum","bizName":"music_down_cnt","expr":"down_cnt","isCreateMetric":1,"name":"下载量"},{"agg":"sum","bizName":"music_js_play_cnt","expr":"js_play_cnt","isCreateMetric":1,"name":"播放量"},{"agg":"sum","bizName":"music_favor_cnt","expr":"favor_cnt","isCreateMetric":1,"name":"收藏量"}],"queryType":"sql_query","sqlQuery":"SELECT imp_date,singer_name,act_area,song_name,genre,js_play_cnt,down_cnt,favor_cnt FROM singer "}', '2023-05-25 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin'); -insert into s2_datasource_rela (id , model_id, `datasource_from`, datasource_to, join_key, created_at, created_by, updated_at, updated_by ) VALUES(1, 1, 1, 2, 'user_name', '2023-05-25 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin'); -insert into s2_datasource_rela (id , model_id, `datasource_from`, 
datasource_to, join_key, created_at, created_by, updated_at, updated_by ) VALUES(2, 1, 1, 3, 'user_name', '2023-05-25 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin'); -insert into s2_datasource_rela (id , model_id, `datasource_from`, datasource_to, join_key, created_at, created_by, updated_at, updated_by ) VALUES(3, 1, 2, 3, 'user_name', '2023-05-25 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin'); -insert into s2_dimension (id , model_id, datasource_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, expr, created_at, created_by, updated_at, updated_by, semantic_type, dim_value_maps) VALUES(1, 1, 3, '部门', 'department', '部门', 1, 0, 'categorical', NULL, 'department', '2023-05-24 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin', 'CATEGORY', '[{"alias":["人力资源","人力"],"bizName":"人力资源","techName":"HR"},{"alias":["营销","销售"],"bizName":"营销部门","techName":"sales"}]'); -insert into s2_dimension (id , model_id, datasource_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, expr, created_at, created_by, updated_at, updated_by, semantic_type) VALUES(2, 1, 1, '用户名', 'user_name', '用户名', 1, 0, 'primary', NULL, 'user_name', '2023-05-24 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin', 'CATEGORY'); -insert into s2_dimension (id , model_id, datasource_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, expr, created_at, created_by, updated_at, updated_by, semantic_type) VALUES(3, 1, 2, '页面', 'page', '页面', 1, 2, 'categorical', NULL, 'page', '2023-05-24 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin', 'CATEGORY'); -insert into s2_dimension (id , model_id, datasource_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, expr, created_at, created_by, updated_at, updated_by, semantic_type) VALUES(4, 2, 4, '活跃区域', 'act_area', '活跃区域', 1, 2, 'categorical', NULL, 'act_area', '2023-05-24 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin', 'CATEGORY'); -insert into 
s2_dimension (id , model_id, datasource_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, expr, created_at, created_by, updated_at, updated_by, semantic_type) VALUES(5, 2, 4, '代表作', 'song_name', '代表作', 1, 2, 'categorical', NULL, 'song_name', '2023-05-24 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin', 'CATEGORY'); -insert into s2_dimension (id , model_id, datasource_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, expr, created_at, created_by, updated_at, updated_by, semantic_type) VALUES(6, 2, 4, '风格', 'genre', '风格', 1, 2, 'categorical', NULL, 'genre', '2023-05-24 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin', 'CATEGORY'); -insert into s2_dimension (id , model_id, datasource_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, expr, created_at, created_by, updated_at, updated_by, semantic_type) VALUES(7, 2, 4, '歌手名', 'singer_name', '歌手名', 1, 2, 'categorical', NULL, 'singer_name', '2023-05-24 00:00:00', 'admin', '2023-05-25 00:00:00', 'admin', 'CATEGORY'); -insert into s2_metric (id, model_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, created_at, created_by, updated_at, updated_by, data_format_type, data_format) VALUES(1, 1, '停留时长', 'stay_hours', '停留时长', 1, 2, 'ATOMIC', '{"expr":"s2_stay_time_statis_stay_hours","measures":[{"agg":"sum","expr":"stay_hours","isCreateMetric":1,"datasourceId":1,"bizName":"s2_stay_time_statis_stay_hours","name":"s2_stay_time_statis_stay_hours"}]}' , '2023-05-24 17:00:00', 'admin', '2023-05-25 00:00:00', 'admin', NULL, NULL ); -insert into s2_metric (id, model_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, created_at, created_by, updated_at, updated_by, data_format_type, data_format) VALUES(2, 1, '访问次数', 'pv', '访问次数', 1, 0, 'ATOMIC', ' 
{"expr":"s2_pv_uv_statis_pv","measures":[{"agg":"sum","bizName":"s2_pv_uv_statis_pv","datasourceId":2,"expr":"pv","isCreateMetric":1,"name":"s2_pv_uv_statis_pv"}]}' , '2023-05-24 17:00:00', 'admin', '2023-05-25 00:00:00', 'admin', NULL, NULL ); -insert into s2_metric (id, model_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, created_at, created_by, updated_at, updated_by, data_format_type, data_format) VALUES(3, 1, '访问人数', 'uv', '访问人数', 1, 0, 'ATOMIC', ' {"expr":"s2_pv_uv_statis_uv","measures":[{"agg":"count_distinct","bizName":"s2_pv_uv_statis_uv","datasourceId":2,"expr":"uv","isCreateMetric":1,"name":"s2_pv_uv_statis_uv"}]}' , '2023-05-24 17:00:00', 'admin', '2023-05-25 00:00:00', 'admin', NULL, NULL ); -insert into s2_metric (id, model_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, created_at, created_by, updated_at, updated_by, data_format_type, data_format) VALUES(4, 2, '播放量', 'js_play_cnt', '播放量', 1, 2, 'ATOMIC', '{"expr":"music_js_play_cnt","measures":[{"agg":"sum","expr":"js_play_cnt","isCreateMetric":1,"datasourceId":4,"bizName":"music_js_play_cnt","name":"music_js_play_cnt"}]}' , '2023-05-24 17:00:00', 'admin', '2023-05-25 00:00:00', 'admin', NULL, NULL ); -insert into s2_metric (id, model_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, created_at, created_by, updated_at, updated_by, data_format_type, data_format) VALUES(5, 2, '下载量', 'down_cnt', '下载量', 1, 0, 'ATOMIC', ' {"expr":"music_down_cnt","measures":[{"agg":"sum","bizName":"music_down_cnt","datasourceId":4,"expr":"down_cnt","isCreateMetric":1,"name":"music_down_cnt"}]}' , '2023-05-24 17:00:00', 'admin', '2023-05-25 00:00:00', 'admin', NULL, NULL ); -insert into s2_metric (id, model_id, `name`, biz_name, description, status, sensitive_level, `type`, type_params, created_at, created_by, updated_at, updated_by, data_format_type, data_format) VALUES(6, 2, '收藏量', 'favor_cnt', '收藏量', 1, 0, 'ATOMIC', ' 
{"expr":"music_favor_cnt","measures":[{"agg":"sum","bizName":"music_favor_cnt","datasourceId":4,"expr":"favor_cnt","isCreateMetric":1,"name":"music_favor_cnt"}]}' , '2023-05-24 17:00:00', 'admin', '2023-05-25 00:00:00', 'admin', NULL, NULL ); - insert into s2_available_date_info(`item_id` ,`type` ,`date_format` ,`start_date` ,`end_date` ,`unavailable_date` ,`created_at` ,`created_by` ,`updated_at` ,`updated_by` ) values (1, 'dimension', 'yyyy-MM-dd', DATEADD('DAY', -28, CURRENT_DATE()), DATEADD('DAY', -1, CURRENT_DATE()), '[]', '2023-06-01', 'admin', '2023-06-01', 'admin'); insert into s2_available_date_info(`item_id` ,`type` ,`date_format` ,`start_date` ,`end_date` ,`unavailable_date` ,`created_at` ,`created_by` ,`updated_at` ,`updated_by` ) @@ -37,11 +11,6 @@ values (2, 'dimension', 'yyyy-MM-dd', DATEADD('DAY', -28, CURRENT_DATE()), DATEA insert into s2_available_date_info(`item_id` ,`type` ,`date_format` ,`start_date` ,`end_date` ,`unavailable_date` ,`created_at` ,`created_by` ,`updated_at` ,`updated_by` ) values (3, 'dimension', 'yyyy-MM-dd', DATEADD('DAY', -28, CURRENT_DATE()), DATEADD('DAY', -1, CURRENT_DATE()), '[]', '2023-06-01', 'admin', '2023-06-01', 'admin'); -insert into s2_auth_groups (group_id, config) -values (1, '{"modelId":"1","name":"admin-permission","groupId":1,"authRules":[{"metrics":["stay_hours"],"dimensions":["page"]}],"dimensionFilters":[""],"dimensionFilterDescription":"授权admin 页面和停留时长权限","authorizedUsers":["admin"],"authorizedDepartmentIds":[]}'); -insert into s2_auth_groups (group_id, config) -values (2, '{"modelId":"1","name":"tom_sales_permission","groupId":2,"authRules":[{"metrics":["stay_hours"],"dimensions":["page"]}],"dimensionFilters":["department in (''sales'')"],"dimensionFilterDescription":"开通 tom sales部门权限", "authorizedUsers":["tom"],"authorizedDepartmentIds":[]}'); - -- sample data INSERT INTO singer (imp_date,singer_name,act_area, song_name,genre,js_play_cnt,down_cnt,favor_cnt) VALUES (DATEADD('DAY', -1, CURRENT_DATE()), 
'周杰伦', '中国','青花瓷','流行',1000000,1000000,1000000); INSERT INTO singer (imp_date,singer_name,act_area, song_name,genre,js_play_cnt,down_cnt,favor_cnt) VALUES (DATEADD('DAY', -5, CURRENT_DATE()), '周杰伦', '中国','青花瓷','流行',1000000,1000000,1000000); diff --git a/pom.xml b/pom.xml index c9bdfff2d..2b605c8f7 100644 --- a/pom.xml +++ b/pom.xml @@ -65,6 +65,11 @@ 4.5.1 4.5 0.7.5-SNAPSHOT + + 2.30.0 + + + 22.3.0 @@ -101,6 +106,15 @@ + + + spotless-python + + src/**/*.py + + + + @@ -147,6 +161,10 @@ org.apache.maven.plugins maven-checkstyle-plugin + + com.diffplug.spotless + spotless-maven-plugin + @@ -185,6 +203,31 @@ + + com.diffplug.spotless + spotless-maven-plugin + ${maven.plugin.spotless.version} + + + true + + + + ${spotless.python.includes} + + + ${spotless.python.black.version} + + + + + + + check + + + + diff --git a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/enums/IdentifyTypeEnum.java b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/enums/IdentifyTypeEnum.java new file mode 100644 index 000000000..595b729d4 --- /dev/null +++ b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/enums/IdentifyTypeEnum.java @@ -0,0 +1,9 @@ +package com.tencent.supersonic.semantic.api.model.enums; + +public enum IdentifyTypeEnum { + + primary, + + foreign, + +} \ No newline at end of file diff --git a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/enums/SemanticTypeEnum.java b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/enums/SemanticTypeEnum.java new file mode 100644 index 000000000..04ed6b09f --- /dev/null +++ b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/enums/SemanticTypeEnum.java @@ -0,0 +1,10 @@ +package com.tencent.supersonic.semantic.api.model.enums; + +public enum SemanticTypeEnum { + + CATEGORY, + ID, + DATE, + NUMBER + +} \ No newline at end of file diff --git 
a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/pojo/Dim.java b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/pojo/Dim.java index 618b7da33..63396ac77 100644 --- a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/pojo/Dim.java +++ b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/pojo/Dim.java @@ -25,6 +25,13 @@ public class Dim { private String bizName; + public Dim(String name, String bizName, String type, Integer isCreateDimension) { + this.name = name; + this.type = type; + this.isCreateDimension = isCreateDimension; + this.bizName = bizName; + } + public static Dim getDefault() { return new Dim("日期", "time", "2023-05-28", Constants.DAY_FORMAT, diff --git a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/pojo/DimensionTimeTypeParams.java b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/pojo/DimensionTimeTypeParams.java index db3a265bd..e852e0567 100644 --- a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/pojo/DimensionTimeTypeParams.java +++ b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/pojo/DimensionTimeTypeParams.java @@ -10,8 +10,8 @@ import lombok.NoArgsConstructor; @NoArgsConstructor public class DimensionTimeTypeParams { - private String isPrimary; + private String isPrimary = "true"; - private String timeGranularity; + private String timeGranularity = "day"; } diff --git a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/pojo/Measure.java b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/pojo/Measure.java index 17efa7938..a6c028f46 100644 --- a/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/pojo/Measure.java +++ b/semantic/api/src/main/java/com/tencent/supersonic/semantic/api/model/pojo/Measure.java @@ -28,5 +28,10 @@ public class Measure { private Long datasourceId; - + public Measure(String name, String bizName, 
String agg, Integer isCreateMetric) { + this.name = name; + this.agg = agg; + this.isCreateMetric = isCreateMetric; + this.bizName = bizName; + } } diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/CatalogImpl.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/CatalogImpl.java index 067b8af44..e7b517ab4 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/CatalogImpl.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/CatalogImpl.java @@ -1,26 +1,30 @@ package com.tencent.supersonic.semantic.model.application; import com.tencent.supersonic.semantic.api.model.pojo.ItemDateFilter; -import com.tencent.supersonic.semantic.api.model.response.MetricResp; import com.tencent.supersonic.semantic.api.model.response.DatabaseResp; -import com.tencent.supersonic.semantic.api.model.response.ModelResp; import com.tencent.supersonic.semantic.api.model.response.DatasourceResp; import com.tencent.supersonic.semantic.api.model.response.DimensionResp; import com.tencent.supersonic.semantic.api.model.response.ItemDateResp; +import com.tencent.supersonic.semantic.api.model.response.MeasureResp; +import com.tencent.supersonic.semantic.api.model.response.MetricResp; +import com.tencent.supersonic.semantic.api.model.response.ModelResp; import com.tencent.supersonic.semantic.api.model.yaml.DatasourceYamlTpl; import com.tencent.supersonic.semantic.api.model.yaml.DimensionYamlTpl; import com.tencent.supersonic.semantic.api.model.yaml.MetricYamlTpl; -import com.tencent.supersonic.semantic.model.domain.DatabaseService; -import com.tencent.supersonic.semantic.model.domain.ModelService; -import com.tencent.supersonic.semantic.model.domain.DimensionService; -import com.tencent.supersonic.semantic.model.domain.DatasourceService; -import com.tencent.supersonic.semantic.model.domain.MetricService; import 
com.tencent.supersonic.semantic.model.domain.Catalog; +import com.tencent.supersonic.semantic.model.domain.DatabaseService; +import com.tencent.supersonic.semantic.model.domain.DatasourceService; +import com.tencent.supersonic.semantic.model.domain.DimensionService; +import com.tencent.supersonic.semantic.model.domain.MetricService; +import com.tencent.supersonic.semantic.model.domain.ModelService; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Optional; import java.util.Set; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Component; +import org.springframework.util.CollectionUtils; @Slf4j @Component @@ -96,4 +100,26 @@ public class CatalogImpl implements Catalog { public ItemDateResp getItemDate(ItemDateFilter dimension, ItemDateFilter metric) { return datasourceService.getItemDate(dimension, metric); } + + @Override + public String getAgg(Long modelId, String metricBizName) { + List metricResps = getMetrics(modelId); + if (!CollectionUtils.isEmpty(metricResps)) { + Optional metric = metricResps.stream() + .filter(m -> m.getBizName().equalsIgnoreCase(metricBizName)).findFirst(); + if (metric.isPresent() && Objects.nonNull(metric.get().getTypeParams()) && !CollectionUtils.isEmpty( + metric.get().getTypeParams().getMeasures())) { + List measureRespList = datasourceService.getMeasureListOfModel(modelId); + if (!CollectionUtils.isEmpty(measureRespList)) { + String measureName = metric.get().getTypeParams().getMeasures().get(0).getBizName(); + Optional measure = measureRespList.stream() + .filter(m -> m.getBizName().equalsIgnoreCase(measureName)).findFirst(); + if (measure.isPresent()) { + return measure.get().getAgg(); + } + } + } + } + return ""; + } } diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/ModelServiceImpl.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/ModelServiceImpl.java index 5c5cf2bca..aedaf2250 100644 
--- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/ModelServiceImpl.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/application/ModelServiceImpl.java @@ -221,7 +221,7 @@ public class ModelServiceImpl implements ModelService { @Override public Map getModelFullPathMap() { - return getModelList().stream().collect(Collectors.toMap(ModelResp::getId, + return getModelList().stream().filter(m -> m != null).collect(Collectors.toMap(ModelResp::getId, ModelResp::getFullPath, (k1, k2) -> k1)); } diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/Catalog.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/Catalog.java index b1c7fca3d..ab4c60127 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/Catalog.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/Catalog.java @@ -1,14 +1,14 @@ package com.tencent.supersonic.semantic.model.domain; import com.tencent.supersonic.semantic.api.model.pojo.ItemDateFilter; -import com.tencent.supersonic.semantic.api.model.yaml.DatasourceYamlTpl; -import com.tencent.supersonic.semantic.api.model.yaml.DimensionYamlTpl; -import com.tencent.supersonic.semantic.api.model.yaml.MetricYamlTpl; import com.tencent.supersonic.semantic.api.model.response.DatabaseResp; import com.tencent.supersonic.semantic.api.model.response.DatasourceResp; import com.tencent.supersonic.semantic.api.model.response.DimensionResp; import com.tencent.supersonic.semantic.api.model.response.ItemDateResp; import com.tencent.supersonic.semantic.api.model.response.MetricResp; +import com.tencent.supersonic.semantic.api.model.yaml.DatasourceYamlTpl; +import com.tencent.supersonic.semantic.api.model.yaml.DimensionYamlTpl; +import com.tencent.supersonic.semantic.api.model.yaml.MetricYamlTpl; import java.util.List; import java.util.Map; import java.util.Set; @@ -16,6 +16,7 @@ import 
java.util.Set; public interface Catalog { DatabaseResp getDatabase(Long id); + DatabaseResp getDatabaseByModelId(Long modelId); List getDatasourceList(Long modelId); @@ -36,4 +37,6 @@ public interface Catalog { ItemDateResp getItemDate(ItemDateFilter dimension, ItemDateFilter metric); + String getAgg(Long modelId, String metricBizName); + } diff --git a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/adaptor/engineadapter/H2Adaptor.java b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/adaptor/engineadapter/H2Adaptor.java index b2295d397..7695ad264 100644 --- a/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/adaptor/engineadapter/H2Adaptor.java +++ b/semantic/model/src/main/java/com/tencent/supersonic/semantic/model/domain/adaptor/engineadapter/H2Adaptor.java @@ -29,8 +29,13 @@ public class H2Adaptor extends EngineAdaptor { @Override public String getColumnMetaQueryTpl() { - return "SELECT COLUMN_NAME AS name, DATA_TYPE AS dataType\n" - + "FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA ='%s' AND TABLE_NAME = '%s'"; + return "SELECT COLUMN_NAME AS name, " + + " case DATA_TYPE" + + " when '12' then 'varchar'" + + " when '-5' then 'integer'" + + " when '8' then 'double'" + + " end AS dataType" + + " FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA ='%s' AND TABLE_NAME = '%s'"; } @Override diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/dsl/Identify.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/dsl/Identify.java index 8dc4173d5..61458caca 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/dsl/Identify.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/dsl/Identify.java @@ -10,6 +10,10 @@ import lombok.NoArgsConstructor; @NoArgsConstructor public class Identify { + public enum Type { + PRIMARY, FOREIGN + } + 
private String name; // primary or foreign diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/sql/TableView.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/sql/TableView.java index db6893612..0659b74d3 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/sql/TableView.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/sql/TableView.java @@ -1,5 +1,6 @@ package com.tencent.supersonic.semantic.query.parser.calcite.sql; +import com.tencent.supersonic.semantic.query.parser.calcite.dsl.DataSource; import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; @@ -25,6 +26,8 @@ public class TableView { private String alias; private List primary; + private DataSource dataSource; + public SqlNode build() { measure.addAll(dimension); SqlNodeList dimensionNodeList = null; diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/sql/node/IdentifyNode.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/sql/node/IdentifyNode.java index 796edb35a..1203ad739 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/sql/node/IdentifyNode.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/sql/node/IdentifyNode.java @@ -1,6 +1,11 @@ package com.tencent.supersonic.semantic.query.parser.calcite.sql.node; import com.tencent.supersonic.semantic.query.parser.calcite.dsl.Identify; +import com.tencent.supersonic.semantic.query.parser.calcite.dsl.Identify.Type; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.validate.SqlValidatorScope; @@ -9,4 +14,28 @@ public class IdentifyNode extends SemanticNode { public static 
SqlNode build(Identify identify, SqlValidatorScope scope) throws Exception { return parse(identify.getName(), scope); } + + public static Set getIdentifyNames(List identifies, Identify.Type type) { + return identifies.stream().filter(i -> type.name().equalsIgnoreCase(i.getType())).map(i -> i.getName()) + .collect(Collectors.toSet()); + + } + + public static boolean isForeign(String name, List identifies) { + Optional identify = identifies.stream().filter(i -> i.getName().equalsIgnoreCase(name)) + .findFirst(); + if (identify.isPresent()) { + return Type.FOREIGN.name().equalsIgnoreCase(identify.get().getType()); + } + return false; + } + + public static boolean isPrimary(String name, List identifies) { + Optional identify = identifies.stream().filter(i -> i.getName().equalsIgnoreCase(name)) + .findFirst(); + if (identify.isPresent()) { + return Type.PRIMARY.name().equalsIgnoreCase(identify.get().getType()); + } + return false; + } } diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/sql/render/JoinRender.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/sql/render/JoinRender.java index a83f7d70e..c61dca9a0 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/sql/render/JoinRender.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/parser/calcite/sql/render/JoinRender.java @@ -5,6 +5,7 @@ import com.tencent.supersonic.semantic.query.parser.calcite.dsl.Constants; import com.tencent.supersonic.semantic.query.parser.calcite.dsl.DataSource; import com.tencent.supersonic.semantic.query.parser.calcite.dsl.Dimension; import com.tencent.supersonic.semantic.query.parser.calcite.dsl.Identify; +import com.tencent.supersonic.semantic.query.parser.calcite.dsl.Identify.Type; import com.tencent.supersonic.semantic.query.parser.calcite.dsl.Metric; import com.tencent.supersonic.semantic.query.parser.calcite.schema.SemanticSchema; 
import com.tencent.supersonic.semantic.query.parser.calcite.sql.Renderer; @@ -12,15 +13,20 @@ import com.tencent.supersonic.semantic.query.parser.calcite.sql.TableView; import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.AggFunctionNode; import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.DataSourceNode; import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.FilterNode; +import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.IdentifyNode; import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.MetricNode; import com.tencent.supersonic.semantic.query.parser.calcite.sql.node.SemanticNode; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Queue; import java.util.Set; import java.util.stream.Collectors; import lombok.extern.slf4j.Slf4j; @@ -33,6 +39,7 @@ import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.validate.SqlValidatorScope; +import org.springframework.util.CollectionUtils; @Slf4j public class JoinRender extends Renderer { @@ -41,6 +48,7 @@ public class JoinRender extends Renderer { public void render(MetricReq metricCommand, List dataSources, SqlValidatorScope scope, SemanticSchema schema, boolean nonAgg) throws Exception { String queryWhere = metricCommand.getWhere(); + dataSources = getOrderSource(dataSources); Set whereFields = new HashSet<>(); List fieldWhere = new ArrayList<>(); if (queryWhere != null && !queryWhere.isEmpty()) { @@ -95,6 +103,7 @@ public class JoinRender extends Renderer { String alias = Constants.JOIN_TABLE_PREFIX + dataSource.getName(); tableView.setAlias(alias); tableView.setPrimary(primary); + 
tableView.setDataSource(dataSource); if (left == null) { leftTable = tableView; left = SemanticNode.buildAs(tableView.getAlias(), getTable(tableView, scope)); @@ -246,7 +255,7 @@ public class JoinRender extends Renderer { private SqlNode getCondition(TableView left, TableView right, DataSource dataSource, SemanticSchema schema, SqlValidatorScope scope) throws Exception { - log.info(left.getClass().toString()); + Set selectLeft = SemanticNode.getSelect(left.getTable()); Set selectRight = SemanticNode.getSelect(right.getTable()); selectLeft.retainAll(selectRight); @@ -255,6 +264,16 @@ public class JoinRender extends Renderer { if (!SourceRender.isDimension(on, dataSource, schema)) { continue; } + if (IdentifyNode.isForeign(on, left.getDataSource().getIdentifiers())) { + if (!IdentifyNode.isPrimary(on, right.getDataSource().getIdentifiers())) { + continue; + } + } + if (IdentifyNode.isForeign(on, right.getDataSource().getIdentifiers())) { + if (!IdentifyNode.isPrimary(on, left.getDataSource().getIdentifiers())) { + continue; + } + } List ons = new ArrayList<>(); ons.add(SemanticNode.parse(left.getAlias() + "." + on, scope)); ons.add(SemanticNode.parse(right.getAlias() + "." 
+ on, scope)); @@ -276,4 +295,85 @@ public class JoinRender extends Renderer { } return condition; } + + private List getOrderSource(List dataSources) throws Exception { + if (CollectionUtils.isEmpty(dataSources) || dataSources.size() <= 2) { + return dataSources; + } + Map> next = new HashMap<>(); + Map visited = new HashMap<>(); + Map> dataSourceIdentifies = new HashMap<>(); + dataSources.stream().forEach(d -> { + next.put(d.getName(), new HashSet<>()); + visited.put(d.getName(), false); + dataSourceIdentifies.put(d.getName(), d.getIdentifiers()); + }); + int cnt = dataSources.size(); + List>> dataSourceIdentifyList = dataSourceIdentifies.entrySet().stream() + .collect( + Collectors.toList()); + for (int i = 0; i < cnt; i++) { + for (int j = i + 1; j < cnt; j++) { + Set primaries = IdentifyNode.getIdentifyNames(dataSourceIdentifyList.get(i).getValue(), + Type.PRIMARY); + Set foreign = IdentifyNode.getIdentifyNames(dataSourceIdentifyList.get(i).getValue(), + Type.FOREIGN); + Set nextPrimaries = IdentifyNode.getIdentifyNames(dataSourceIdentifyList.get(j).getValue(), + Type.PRIMARY); + Set nextForeign = IdentifyNode.getIdentifyNames(dataSourceIdentifyList.get(j).getValue(), + Type.FOREIGN); + Set nextAll = new HashSet<>(); + nextAll.addAll(nextPrimaries); + nextAll.addAll(nextForeign); + primaries.retainAll(nextPrimaries); + foreign.retainAll(nextPrimaries); + if (primaries.size() > 0 || foreign.size() > 0) { + next.get(dataSourceIdentifyList.get(i).getKey()).add(dataSourceIdentifyList.get(j).getKey()); + next.get(dataSourceIdentifyList.get(j).getKey()).add(dataSourceIdentifyList.get(i).getKey()); + } + + } + } + Queue paths = new ArrayDeque<>(); + for (String id : visited.keySet()) { + if (!visited.get(id)) { + joinOrder(cnt, id, next, paths, visited); + if (paths.size() >= cnt) { + break; + } + } + } + if (paths.size() < cnt) { + throw new Exception("datasource cant join,pls check identify :" + dataSources.stream() + .map(d -> d.getName()).collect( + 
Collectors.joining(","))); + } + List orderList = new ArrayList<>(paths); + Collections.sort(dataSources, new Comparator() { + @Override + public int compare(DataSource o1, DataSource o2) { + return orderList.indexOf(o1.getName()) - orderList.indexOf(o2.getName()); + } + }); + return dataSources; + } + + private static void joinOrder(int cnt, String id, Map> next, Queue orders, + Map visited) { + visited.put(id, true); + orders.add(id); + if (orders.size() >= cnt) { + return; + } + for (String nextId : next.get(id)) { + if (!visited.get(nextId)) { + joinOrder(cnt, nextId, next, orders, visited); + if (orders.size() >= cnt) { + return; + } + } + } + orders.poll(); + visited.put(id, false); + } } diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/AuthCommonService.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/AuthCommonService.java new file mode 100644 index 000000000..e6a30a620 --- /dev/null +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/AuthCommonService.java @@ -0,0 +1,258 @@ +package com.tencent.supersonic.semantic.query.service; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.auth.api.authorization.pojo.AuthRes; +import com.tencent.supersonic.auth.api.authorization.pojo.AuthResGrp; +import com.tencent.supersonic.auth.api.authorization.pojo.DimensionFilter; +import com.tencent.supersonic.auth.api.authorization.request.QueryAuthResReq; +import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp; +import com.tencent.supersonic.auth.api.authorization.service.AuthService; +import com.tencent.supersonic.common.pojo.Constants; +import com.tencent.supersonic.common.pojo.QueryAuthorization; +import com.tencent.supersonic.common.pojo.QueryColumn; +import 
com.tencent.supersonic.common.pojo.enums.AuthType; +import com.tencent.supersonic.common.pojo.exception.InvalidPermissionException; +import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem; +import com.tencent.supersonic.semantic.api.model.response.DimensionResp; +import com.tencent.supersonic.semantic.api.model.response.MetricResp; +import com.tencent.supersonic.semantic.api.model.response.ModelResp; +import com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; +import com.tencent.supersonic.semantic.model.domain.DimensionService; +import com.tencent.supersonic.semantic.model.domain.MetricService; +import com.tencent.supersonic.semantic.model.domain.ModelService; +import lombok.extern.slf4j.Slf4j; +import org.assertj.core.util.Sets; +import org.springframework.beans.BeanUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; +import org.springframework.util.CollectionUtils; + +import java.text.SimpleDateFormat; +import java.util.List; +import java.util.ArrayList; +import java.util.Map; +import java.util.HashMap; +import java.util.Set; +import java.util.HashSet; + +import java.util.stream.Collectors; + +@Service +@Slf4j +public class AuthCommonService { + private static final ObjectMapper MAPPER = new ObjectMapper().setDateFormat( + new SimpleDateFormat(Constants.DAY_FORMAT)); + @Autowired + private AuthService authService; + @Autowired + private DimensionService dimensionService; + @Autowired + private MetricService metricService; + + @Autowired + private ModelService modelService; + + public boolean doModelAdmin(User user, Long modelId) { + List modelListAdmin = modelService.getModelListWithAuth(user, null, AuthType.ADMIN); + if (CollectionUtils.isEmpty(modelListAdmin)) { + return false; + } else { + Map> id2modelResp = modelListAdmin.stream() + .collect(Collectors.groupingBy(SchemaItem::getId)); + return !CollectionUtils.isEmpty(id2modelResp) && 
id2modelResp.containsKey(modelId); + } + } + + public void doModelVisible(User user, Long modelId) { + Boolean visible = true; + List modelListVisible = modelService.getModelListWithAuth(user, null, AuthType.VISIBLE); + if (CollectionUtils.isEmpty(modelListVisible)) { + visible = false; + } else { + Map> id2domainDesc = modelListVisible.stream() + .collect(Collectors.groupingBy(SchemaItem::getId)); + if (!CollectionUtils.isEmpty(id2domainDesc) && !id2domainDesc.containsKey(modelId)) { + visible = false; + } + } + if (!visible) { + ModelResp modelResp = modelService.getModel(modelId); + String modelName = modelResp.getName(); + List admins = modelService.getModelAdmin(modelResp.getId()); + String message = String.format("您没有主题域[%s]权限,请联系管理员%s开通", modelName, admins); + throw new InvalidPermissionException(message); + } + + } + + public Set getHighSensitiveColsByModelId(Long modelId) { + Set highSensitiveCols = new HashSet<>(); + List highSensitiveDimensions = dimensionService.getHighSensitiveDimension(modelId); + List highSensitiveMetrics = metricService.getHighSensitiveMetric(modelId); + if (!CollectionUtils.isEmpty(highSensitiveDimensions)) { + highSensitiveDimensions.stream().forEach(dim -> highSensitiveCols.add(dim.getBizName())); + } + if (!CollectionUtils.isEmpty(highSensitiveMetrics)) { + highSensitiveMetrics.stream().forEach(metric -> highSensitiveCols.add(metric.getBizName())); + } + return highSensitiveCols; + } + + public AuthorizedResourceResp getAuthorizedResource(User user, Long domainId, + Set sensitiveResReq) { + List resourceReqList = new ArrayList<>(); + sensitiveResReq.forEach(res -> resourceReqList.add(new AuthRes(domainId.toString(), res))); + QueryAuthResReq queryAuthResReq = new QueryAuthResReq(); + queryAuthResReq.setResources(resourceReqList); + queryAuthResReq.setModelId(domainId + ""); + AuthorizedResourceResp authorizedResource = fetchAuthRes(queryAuthResReq, user); + log.info("user:{}, domainId:{}, after queryAuthorizedResources:{}", 
user.getName(), domainId, + authorizedResource); + return authorizedResource; + } + private AuthorizedResourceResp fetchAuthRes(QueryAuthResReq queryAuthResReq, User user) { + log.info("queryAuthResReq:{}", queryAuthResReq); + return authService.queryAuthorizedResources(queryAuthResReq, user); + } + public Set getAuthResNameSet(AuthorizedResourceResp authorizedResource, Long domainId) { + Set resAuthName = new HashSet<>(); + List authResGrpList = authorizedResource.getResources(); + authResGrpList.stream().forEach(authResGrp -> { + List cols = authResGrp.getGroup(); + if (!CollectionUtils.isEmpty(cols)) { + cols.stream().filter(col -> domainId.equals(Long.parseLong(col.getModelId()))) + .forEach(col -> resAuthName.add(col.getName())); + } + + }); + log.info("resAuthName:{}", resAuthName); + return resAuthName; + } + public boolean allSensitiveResReqIsOk(Set sensitiveResReq, Set resAuthSet) { + if (resAuthSet.containsAll(sensitiveResReq)) { + return true; + } + log.info("sensitiveResReq:{}, resAuthSet:{}", sensitiveResReq, resAuthSet); + return false; + } + + public QueryResultWithSchemaResp getQueryResultWithColumns(QueryResultWithSchemaResp resultWithColumns, + Long domainId, AuthorizedResourceResp authResource) { + addPromptInfoInfo(domainId, resultWithColumns, authResource, Sets.newHashSet()); + return resultWithColumns; + } + + public QueryResultWithSchemaResp desensitizationData(QueryResultWithSchemaResp raw, Set need2Apply) { + log.debug("start desensitizationData logic"); + if (CollectionUtils.isEmpty(need2Apply)) { + log.info("user has all sensitiveRes"); + return raw; + } + + List columns = raw.getColumns(); + + boolean doDesensitization = false; + for (QueryColumn queryColumn : columns) { + if (need2Apply.contains(queryColumn.getNameEn())) { + doDesensitization = true; + break; + } + } + if (!doDesensitization) { + return raw; + } + + QueryResultWithSchemaResp queryResultWithColumns = raw; + try { + queryResultWithColumns = deepCopyResult(raw); + } catch 
(Exception e) { + log.warn("deepCopyResult: ", e); + } + addAuthorizedSchemaInfo(queryResultWithColumns.getColumns(), need2Apply); + desensitizationInternal(queryResultWithColumns.getResultList(), need2Apply); + return queryResultWithColumns; + } + + private void addAuthorizedSchemaInfo(List columns, Set need2Apply) { + if (CollectionUtils.isEmpty(need2Apply)) { + return; + } + columns.stream().forEach(col -> { + if (need2Apply.contains(col.getNameEn())) { + col.setAuthorized(false); + } + }); + } + + private void desensitizationInternal(List> result, Set need2Apply) { + log.info("start desensitizationInternal logic"); + for (int i = 0; i < result.size(); i++) { + Map row = result.get(i); + Map newRow = new HashMap<>(); + for (String col : row.keySet()) { + if (need2Apply.contains(col)) { + newRow.put(col, "****"); + } else { + newRow.put(col, row.get(col)); + } + } + result.set(i, newRow); + } + } + + private QueryResultWithSchemaResp deepCopyResult(QueryResultWithSchemaResp raw) throws Exception { + QueryResultWithSchemaResp queryResultWithColumns = new QueryResultWithSchemaResp(); + BeanUtils.copyProperties(raw, queryResultWithColumns); + + List columns = new ArrayList<>(); + if (!CollectionUtils.isEmpty(raw.getColumns())) { + String columnsStr = MAPPER.writeValueAsString(raw.getColumns()); + columns = MAPPER.readValue(columnsStr, new TypeReference>() { + }); + queryResultWithColumns.setColumns(columns); + } + queryResultWithColumns.setColumns(columns); + + List> resultData = new ArrayList<>(); + if (!CollectionUtils.isEmpty(raw.getResultList())) { + for (Map line : raw.getResultList()) { + Map newLine = new HashMap<>(); + newLine.putAll(line); + resultData.add(newLine); + } + } + queryResultWithColumns.setResultList(resultData); + return queryResultWithColumns; + } + + public void addPromptInfoInfo(Long modelId, QueryResultWithSchemaResp queryResultWithColumns, + AuthorizedResourceResp authorizedResource, Set need2Apply) { + List filters = 
authorizedResource.getFilters(); + if (CollectionUtils.isEmpty(need2Apply) && CollectionUtils.isEmpty(filters)) { + return; + } + List admins = modelService.getModelAdmin(modelId); + if (!CollectionUtils.isEmpty(need2Apply)) { + String promptInfo = String.format("当前结果已经过脱敏处理, 申请权限请联系管理员%s", admins); + queryResultWithColumns.setQueryAuthorization(new QueryAuthorization(promptInfo)); + } + if (!CollectionUtils.isEmpty(filters)) { + log.debug("dimensionFilters:{}", filters); + ModelResp modelResp = modelService.getModel(modelId); + List exprList = new ArrayList<>(); + List descList = new ArrayList<>(); + filters.stream().forEach(filter -> { + descList.add(filter.getDescription()); + exprList.add(filter.getExpressions().toString()); + }); + String promptInfo = "当前结果已经过行权限过滤,详细过滤条件如下:%s, 申请权限请联系管理员%s"; + String message = String.format(promptInfo, CollectionUtils.isEmpty(descList) ? exprList : descList, admins); + + queryResultWithColumns.setQueryAuthorization( + new QueryAuthorization(modelResp.getName(), exprList, descList, message)); + log.info("queryResultWithColumns:{}", queryResultWithColumns); + } + } +} diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java index 298a0e897..d39287f61 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/service/QueryServiceImpl.java @@ -21,6 +21,7 @@ import com.tencent.supersonic.semantic.api.query.request.QueryDslReq; import com.tencent.supersonic.semantic.api.query.request.QueryMultiStructReq; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import com.tencent.supersonic.semantic.api.query.response.ItemUseResp; +import com.tencent.supersonic.semantic.query.utils.DslPermissionAnnotation; import 
com.tencent.supersonic.semantic.query.executor.QueryExecutor; import com.tencent.supersonic.semantic.query.parser.convert.QueryReqConverter; import com.tencent.supersonic.semantic.query.persistence.pojo.QueryStatement; @@ -66,9 +67,16 @@ public class QueryServiceImpl implements QueryService { } @Override - public Object queryBySql(QueryDslReq querySqlCmd, User user) throws Exception { + @DslPermissionAnnotation + @SneakyThrows + public Object queryBySql(QueryDslReq querySqlCmd, User user) { statUtils.initStatInfo(querySqlCmd, user); - QueryStatement queryStatement = convertToQueryStatement(querySqlCmd, user); + QueryStatement queryStatement = new QueryStatement(); + try { + queryStatement = convertToQueryStatement(querySqlCmd, user); + } catch (Exception e) { + log.info("convertToQueryStatement has a exception:{}", e.toString()); + } QueryResultWithSchemaResp results = semanticQueryEngine.execute(queryStatement); statUtils.statInfo2DbAsync(TaskStatusEnum.SUCCESS); return results; diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DslDataAspect.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DslDataAspect.java new file mode 100644 index 000000000..15d706c6f --- /dev/null +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DslDataAspect.java @@ -0,0 +1,188 @@ +package com.tencent.supersonic.semantic.query.utils; + +import com.google.common.base.Strings; +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.auth.api.authorization.response.AuthorizedResourceResp; +import com.tencent.supersonic.common.pojo.Constants; +import com.tencent.supersonic.common.pojo.exception.InvalidPermissionException; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserUpdateHelper; +import com.tencent.supersonic.semantic.api.model.response.DimensionResp; +import com.tencent.supersonic.semantic.api.model.response.ModelResp; +import 
com.tencent.supersonic.semantic.api.model.response.QueryResultWithSchemaResp; +import com.tencent.supersonic.semantic.api.query.request.QueryDslReq; +import com.tencent.supersonic.semantic.model.domain.DimensionService; +import com.tencent.supersonic.semantic.model.domain.ModelService; +import com.tencent.supersonic.semantic.query.service.AuthCommonService; +import lombok.extern.slf4j.Slf4j; +import net.sf.jsqlparser.JSQLParserException; +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.parser.CCJSqlParserUtil; +import org.apache.commons.lang3.StringUtils; +import org.aspectj.lang.ProceedingJoinPoint; +import org.aspectj.lang.annotation.Around; +import org.aspectj.lang.annotation.Aspect; +import org.aspectj.lang.annotation.Pointcut; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.core.annotation.Order; +import org.springframework.stereotype.Component; +import org.springframework.util.CollectionUtils; + +import java.util.StringJoiner; +import java.util.Objects; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.HashSet; + +import java.util.stream.Collectors; + +import static com.tencent.supersonic.common.pojo.Constants.MINUS; + +@Component +@Aspect +@Order(1) +@Slf4j +public class DslDataAspect { + + @Autowired + private QueryStructUtils queryStructUtils; + @Autowired + private DimensionService dimensionService; + @Autowired + private ModelService modelService; + @Autowired + private AuthCommonService authCommonService; + @Value("${permission.data.enable:true}") + private Boolean permissionDataEnable; + + @Pointcut("@annotation(com.tencent.supersonic.semantic.query.utils.DslPermissionAnnotation)") + private void dslPermissionCheck() { + } + + @Around("dslPermissionCheck()") + public Object doAround(ProceedingJoinPoint joinPoint) throws Throwable { + log.info("dsl permission check!"); + Object[] 
objects = joinPoint.getArgs(); + QueryDslReq queryDslReq = (QueryDslReq) objects[0]; + User user = (User) objects[1]; + if (!permissionDataEnable) { + log.info("not to check dsl permission!"); + return joinPoint.proceed(); + } + if (Objects.isNull(user) || Strings.isNullOrEmpty(user.getName())) { + throw new RuntimeException("please provide user information"); + } + Long modelId = queryDslReq.getModelId(); + + //1. determine whether admin of the model + if (authCommonService.doModelAdmin(user, modelId)) { + return joinPoint.proceed(); + } + + // 2. determine whether the subject field is visible + authCommonService.doModelVisible(user, modelId); + + // 3. fetch data permission meta information + Set res4Privilege = queryStructUtils.getResNameEnExceptInternalCol(queryDslReq); + log.info("modelId:{}, res4Privilege:{}", modelId, res4Privilege); + + Set sensitiveResByModel = authCommonService.getHighSensitiveColsByModelId(modelId); + Set sensitiveResReq = res4Privilege.parallelStream() + .filter(sensitiveResByModel::contains).collect(Collectors.toSet()); + log.info("this query modelId:{}, sensitiveResReq:{}", modelId, sensitiveResReq); + + // query user privilege info + AuthorizedResourceResp authorizedResource = authCommonService + .getAuthorizedResource(user, modelId, sensitiveResReq); + // get sensitiveRes that user has privilege + Set resAuthSet = authCommonService.getAuthResNameSet(authorizedResource, modelId); + + // 4.if sensitive fields without permission are involved in filter, throw an exception + doFilterCheckLogic(queryDslReq, resAuthSet, sensitiveResReq); + + // 5.row permission pre-filter + doRowPermission(queryDslReq, authorizedResource); + + // 6.proceed + QueryResultWithSchemaResp queryResultWithColumns = (QueryResultWithSchemaResp) joinPoint.proceed(); + + if (CollectionUtils.isEmpty(sensitiveResReq) || authCommonService + .allSensitiveResReqIsOk(sensitiveResReq, resAuthSet)) { + // if sensitiveRes is empty + log.info("sensitiveResReq is empty"); + 
return authCommonService.getQueryResultWithColumns(queryResultWithColumns, modelId, authorizedResource); + } + + // 7.if the column has no permission, hit * + Set need2Apply = sensitiveResReq.stream().filter(req -> !resAuthSet.contains(req)) + .collect(Collectors.toSet()); + QueryResultWithSchemaResp queryResultAfterDesensitization = authCommonService + .desensitizationData(queryResultWithColumns, need2Apply); + authCommonService.addPromptInfoInfo(modelId, queryResultAfterDesensitization, authorizedResource, need2Apply); + + return queryResultAfterDesensitization; + } + + private void doRowPermission(QueryDslReq queryDslReq, AuthorizedResourceResp authorizedResource) { + log.debug("start doRowPermission logic"); + StringJoiner joiner = new StringJoiner(" OR "); + List dimensionFilters = new ArrayList<>(); + if (!CollectionUtils.isEmpty(authorizedResource.getFilters())) { + authorizedResource.getFilters().stream() + .forEach(filter -> dimensionFilters.addAll(filter.getExpressions())); + } + + if (CollectionUtils.isEmpty(dimensionFilters)) { + log.debug("dimensionFilters is empty"); + return; + } + + dimensionFilters.stream().forEach(filter -> { + if (StringUtils.isNotEmpty(filter) && StringUtils.isNotEmpty(filter.trim())) { + joiner.add(" ( " + filter + " ) "); + } + }); + try { + Expression expression = CCJSqlParserUtil.parseCondExpression(" ( " + joiner.toString() + " ) "); + if (StringUtils.isNotEmpty(joiner.toString())) { + String sql = SqlParserUpdateHelper.addWhere(queryDslReq.getSql(), expression); + log.info("before doRowPermission, queryDslReq:{}", queryDslReq.getSql()); + queryDslReq.setSql(sql); + log.info("after doRowPermission, queryDslReq:{}", queryDslReq.getSql()); + } + } catch (JSQLParserException jsqlParserException) { + log.info("jsqlParser has an exception:{}", jsqlParserException.toString()); + } + + } + + private void doFilterCheckLogic(QueryDslReq queryDslReq, Set resAuthName, + Set sensitiveResReq) { + Set resFilterSet = 
queryStructUtils.getFilterResNameEnExceptInternalCol(queryDslReq); + Set need2Apply = resFilterSet.stream() + .filter(res -> !resAuthName.contains(res) && sensitiveResReq.contains(res)).collect(Collectors.toSet()); + Set nameCnSet = new HashSet<>(); + + List modelIds = new ArrayList<>(); + modelIds.add(queryDslReq.getModelId()); + List modelInfos = modelService.getModelList(modelIds); + String modelNameCn = Constants.EMPTY; + if (!CollectionUtils.isEmpty(modelInfos)) { + modelNameCn = modelInfos.get(0).getName(); + } + + List dimensionDescList = dimensionService.getDimensions(queryDslReq.getModelId()); + String finalDomainNameCn = modelNameCn; + dimensionDescList.stream().filter(dim -> need2Apply.contains(dim.getBizName())) + .forEach(dim -> nameCnSet.add(finalDomainNameCn + MINUS + dim.getName())); + + if (!CollectionUtils.isEmpty(need2Apply)) { + ModelResp modelResp = modelInfos.get(0); + List admins = modelService.getModelAdmin(modelResp.getId()); + log.info("in doFilterLogic, need2Apply:{}", need2Apply); + String message = String.format("您没有以下维度%s权限, 请联系管理员%s开通", nameCnSet, admins); + throw new InvalidPermissionException(message); + } + } +} diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DslPermissionAnnotation.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DslPermissionAnnotation.java new file mode 100644 index 000000000..8a9c368dd --- /dev/null +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/DslPermissionAnnotation.java @@ -0,0 +1,14 @@ +package com.tencent.supersonic.semantic.query.utils; + +import java.lang.annotation.Target; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.Documented; + +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface DslPermissionAnnotation { + +} diff --git 
a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryStructUtils.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryStructUtils.java index a13a2907d..65add3e75 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryStructUtils.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/QueryStructUtils.java @@ -6,11 +6,13 @@ import com.tencent.supersonic.common.pojo.DateConf.DateMode; import com.tencent.supersonic.common.pojo.enums.TypeEnums; import com.tencent.supersonic.common.pojo.Aggregator; import com.tencent.supersonic.common.pojo.DateConf; +import com.tencent.supersonic.common.util.jsqlparser.SqlParserSelectHelper; import com.tencent.supersonic.semantic.api.model.pojo.SchemaItem; import com.tencent.supersonic.semantic.api.model.pojo.ItemDateFilter; import com.tencent.supersonic.semantic.api.model.response.DimensionResp; import com.tencent.supersonic.semantic.api.model.response.ItemDateResp; import com.tencent.supersonic.semantic.api.model.response.MetricResp; +import com.tencent.supersonic.semantic.api.query.request.QueryDslReq; import com.tencent.supersonic.semantic.api.query.request.QueryStructReq; import com.tencent.supersonic.semantic.model.domain.Catalog; @@ -145,11 +147,19 @@ public class QueryStructUtils { sqlFilterUtils.getFiltersCol(queryStructCmd.getOriginalFilter()).stream().forEach(col -> resNameEnSet.add(col)); return resNameEnSet; } - + public Set getResNameEn(QueryDslReq queryDslReq) { + Set resNameEnSet = SqlParserSelectHelper.getAllFields(queryDslReq.getSql()) + .stream().collect(Collectors.toSet()); + return resNameEnSet; + } public Set getResNameEnExceptInternalCol(QueryStructReq queryStructCmd) { Set resNameEnSet = getResNameEn(queryStructCmd); return resNameEnSet.stream().filter(res -> !internalCols.contains(res)).collect(Collectors.toSet()); } + public Set getResNameEnExceptInternalCol(QueryDslReq queryDslReq) { + Set 
resNameEnSet = getResNameEn(queryDslReq); + return resNameEnSet.stream().filter(res -> !internalCols.contains(res)).collect(Collectors.toSet()); + } public Set getFilterResNameEn(QueryStructReq queryStructCmd) { Set resNameEnSet = new HashSet<>(); @@ -162,6 +172,12 @@ public class QueryStructUtils { return resNameEnSet.stream().filter(res -> !internalCols.contains(res)).collect(Collectors.toSet()); } + public Set getFilterResNameEnExceptInternalCol(QueryDslReq queryDslReq) { + String sql = queryDslReq.getSql(); + Set resNameEnSet = SqlParserSelectHelper.getWhereFields(sql).stream().collect(Collectors.toSet()); + return resNameEnSet.stream().filter(res -> !internalCols.contains(res)).collect(Collectors.toSet()); + } + public String generateInternalMetricName(Long modelId, List groups) { String internalMetricNamePrefix = ""; if (CollectionUtils.isEmpty(groups)) { diff --git a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/SqlFilterUtils.java b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/SqlFilterUtils.java index 40f8caa93..afdfd88a9 100644 --- a/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/SqlFilterUtils.java +++ b/semantic/query/src/main/java/com/tencent/supersonic/semantic/query/utils/SqlFilterUtils.java @@ -5,10 +5,10 @@ import static com.tencent.supersonic.common.pojo.Constants.PARENTHESES_START; import static com.tencent.supersonic.common.pojo.Constants.SPACE; import static com.tencent.supersonic.common.pojo.Constants.SYS_VAR; +import com.tencent.supersonic.common.pojo.Constants; import com.tencent.supersonic.semantic.api.query.enums.FilterOperatorEnum; import com.tencent.supersonic.semantic.api.query.pojo.Criterion; import com.tencent.supersonic.semantic.api.query.pojo.Filter; -import com.tencent.supersonic.common.pojo.Constants; import java.util.ArrayList; import java.util.List; import java.util.Objects; @@ -25,6 +25,7 @@ import org.springframework.util.CollectionUtils; public 
class SqlFilterUtils { private static String pattern = "^'.*?'$"; + private static String numericPattern = "^[0-9]+$"; public List getFiltersCol(List filters) { List filterCols = new ArrayList<>(); @@ -219,7 +220,7 @@ public class SqlFilterUtils { } private String valueApostropheLogic(String value) { - if (Pattern.matches(pattern, value)) { + if (Pattern.matches(pattern, value) || Pattern.matches(numericPattern, value)) { return value; } return Constants.APOSTROPHE + value + Constants.APOSTROPHE;