[improvement][launcher]Refactor unit tests and demo data. (#1935)

This commit is contained in:
Jun Zhang
2024-12-01 21:08:26 +08:00
committed by GitHub
parent 639d1a78da
commit 0fc29304a8
67 changed files with 2181 additions and 2373 deletions

View File

@@ -17,5 +17,7 @@ public class DbSchema {
private String sql;
private String ddl;
private List<DBColumn> dbColumns;
}

View File

@@ -1,6 +1,7 @@
package com.tencent.supersonic.headless.api.pojo.request;
import com.tencent.supersonic.common.pojo.ChatModelConfig;
import com.tencent.supersonic.headless.api.pojo.DbSchema;
import lombok.Data;
import java.util.List;
@@ -10,12 +11,16 @@ public class ModelBuildReq {
private Long databaseId;
private Long domainId;
private String sql;
private String db;
private List<String> tables;
private List<DbSchema> dbSchemas;
private boolean buildByLLM;
private Integer chatModelId;

View File

@@ -78,7 +78,9 @@ public class PostgresqlAdaptor extends BaseDbAdaptor {
}
return o;
});
return SqlReplaceHelper.replaceFunction(sql, functionMap, functionCall);
sql = SqlReplaceHelper.replaceFunction(sql, functionMap, functionCall);
sql = sql.replaceAll("`", "\"");
return sql;
}
public List<String> getTables(ConnectInfo connectionInfo, String schemaName)

View File

@@ -1,10 +1,15 @@
package com.tencent.supersonic.headless.server.persistence.dataobject;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
@Data
@TableName("s2_available_date_info")
public class DateInfoDO {
@TableId(type = IdType.AUTO)
private Long id;
private String type;
private Long itemId;

View File

@@ -0,0 +1,52 @@
package com.tencent.supersonic.headless.server.persistence.dataobject;
import com.baomidou.mybatisplus.annotation.IdType;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import lombok.Data;
@Data
@TableName("s2_query_stat_info")
public class QueryStatDO {
@TableId(type = IdType.AUTO)
private Long id;
private String traceId;
private Long modelId;
private Long dataSetId;
@TableField("query_user")
private String user;
private String createdAt;
/** corresponding type, such as sql, struct, etc */
private String queryType;
/** NORMAL, PRE_FLUSH */
private Integer queryTypeBack;
private String querySqlCmd;
@TableField("sql_cmd_md5")
private String querySqlCmdMd5;
private String queryStructCmd;
@TableField("struct_cmd_md5")
private String queryStructCmdMd5;
private String sql;
private String sqlMd5;
private String queryEngine;
// private Long startTime;
private Long elapsedMs;
private String queryState;
private Boolean nativeQuery;
private String startDate;
private String endDate;
private String dimensions;
private String metrics;
private String selectCols;
private String aggCols;
private String filterCols;
private String groupByCols;
private String orderByCols;
private Boolean useResultCache;
private Boolean useSqlCache;
private String sqlCacheKey;
private String resultCacheKey;
private String queryOptMode;
}

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.server.persistence.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.tencent.supersonic.headless.api.pojo.ItemDateFilter;
import com.tencent.supersonic.headless.server.persistence.dataobject.DateInfoDO;
import org.apache.ibatis.annotations.Mapper;
@@ -7,9 +8,7 @@ import org.apache.ibatis.annotations.Mapper;
import java.util.List;
@Mapper
public interface DateInfoMapper {
Boolean upsertDateInfo(DateInfoDO dateInfoDO);
public interface DateInfoMapper extends BaseMapper<DateInfoDO> {
List<DateInfoDO> getDateInfos(ItemDateFilter itemDateFilter);
}

View File

@@ -1,15 +1,15 @@
package com.tencent.supersonic.headless.server.persistence.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.tencent.supersonic.headless.api.pojo.QueryStat;
import com.tencent.supersonic.headless.api.pojo.request.ItemUseReq;
import com.tencent.supersonic.headless.server.persistence.dataobject.QueryStatDO;
import org.apache.ibatis.annotations.Mapper;
import java.util.List;
@Mapper
public interface StatMapper {
Boolean createRecord(QueryStat queryStatInfo);
public interface StatMapper extends BaseMapper<QueryStatDO> {
List<QueryStat> getStatInfo(ItemUseReq itemUseCommend);
}

View File

@@ -15,4 +15,6 @@ public interface DomainRepository {
List<DomainDO> getDomainList();
DomainDO getDomainById(Long id);
List<DomainDO> getDomainByBizName(String bizName);
}

View File

@@ -12,5 +12,4 @@ public interface StatRepository {
List<ItemUseResp> getStatInfo(ItemUseReq itemUseCommend);
List<QueryStat> getQueryStatInfoWithoutCache(ItemUseReq itemUseCommend);
}

View File

@@ -66,7 +66,7 @@ public class DateInfoRepositoryImpl implements DateInfoRepository {
private Integer batchUpsert(List<DateInfoDO> dateInfoDOList) {
Stopwatch stopwatch = Stopwatch.createStarted();
for (DateInfoDO dateInfoDO : dateInfoDOList) {
dateInfoMapper.upsertDateInfo(dateInfoDO);
dateInfoMapper.insertOrUpdate(dateInfoDO);
}
log.info("before final, elapsed time:{}", stopwatch.elapsed(TimeUnit.MILLISECONDS));
return 0;

View File

@@ -1,5 +1,6 @@
package com.tencent.supersonic.headless.server.persistence.repository.impl;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.toolkit.Wrappers;
import com.tencent.supersonic.headless.server.persistence.dataobject.DomainDO;
import com.tencent.supersonic.headless.server.persistence.mapper.DomainDOMapper;
@@ -43,4 +44,12 @@ public class DomainRepositoryImpl implements DomainRepository {
public DomainDO getDomainById(Long id) {
return domainDOMapper.selectById(id);
}
@Override
public List<DomainDO> getDomainByBizName(String bizName) {
QueryWrapper<DomainDO> queryWrapper = new QueryWrapper<>();
queryWrapper.lambda().eq(DomainDO::getBizName, bizName);
return domainDOMapper.selectList(queryWrapper);
}
}

View File

@@ -6,11 +6,13 @@ import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.headless.api.pojo.QueryStat;
import com.tencent.supersonic.headless.api.pojo.request.ItemUseReq;
import com.tencent.supersonic.headless.api.pojo.response.ItemUseResp;
import com.tencent.supersonic.headless.server.persistence.dataobject.QueryStatDO;
import com.tencent.supersonic.headless.server.persistence.mapper.StatMapper;
import com.tencent.supersonic.headless.server.persistence.repository.StatRepository;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.stereotype.Repository;
import java.util.ArrayList;
@@ -36,7 +38,9 @@ public class StatRepositoryImpl implements StatRepository {
@Override
public Boolean createRecord(QueryStat queryStatInfo) {
return statMapper.createRecord(queryStatInfo);
QueryStatDO queryStatDO = new QueryStatDO();
BeanUtils.copyProperties(queryStatInfo, queryStatDO);
return statMapper.insertOrUpdate(queryStatDO);
}
@Override
@@ -66,11 +70,6 @@ public class StatRepositoryImpl implements StatRepository {
.collect(Collectors.toList());
}
@Override
public List<QueryStat> getQueryStatInfoWithoutCache(ItemUseReq itemUseCommend) {
return statMapper.getStatInfo(itemUseCommend);
}
private void updateStatMapInfo(Map<String, Long> map, String dimensions, String type,
Long dataSetId) {
if (StringUtils.isNotEmpty(dimensions)) {
@@ -92,14 +91,4 @@ public class StatRepositoryImpl implements StatRepository {
}
}
private void updateStatMapInfo(Map<String, Long> map, Long modelId, String type) {
if (Objects.nonNull(modelId)) {
String key = type + AT_SYMBOL + AT_SYMBOL + modelId;
if (map.containsKey(key)) {
map.put(key, map.get(key) + 1);
} else {
map.put(key, 1L);
}
}
}
}

View File

@@ -50,6 +50,14 @@ public class ModelController {
return true;
}
@PostMapping("/createModelBatch")
public Boolean createModelBatch(@RequestBody ModelBuildReq modelBuildReq,
HttpServletRequest request, HttpServletResponse response) throws Exception {
User user = UserHolder.findUser(request, response);
modelService.createModel(modelBuildReq, user);
return true;
}
@PostMapping("/updateModel")
public Boolean updateModel(@RequestBody ModelReq modelReq, HttpServletRequest request,
HttpServletResponse response) throws Exception {

View File

@@ -2,6 +2,7 @@ package com.tencent.supersonic.headless.server.service;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.enums.DataType;
import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
import com.tencent.supersonic.headless.api.pojo.request.SqlExecuteReq;
@@ -17,6 +18,8 @@ public interface DatabaseService {
SemanticQueryResp executeSql(String sql, DatabaseResp databaseResp);
List<DatabaseResp> getDatabaseByType(DataType dataType);
SemanticQueryResp executeSql(SqlExecuteReq sqlExecuteReq, Long id, User user);
DatabaseResp getDatabase(Long id, User user);

View File

@@ -5,6 +5,7 @@ import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.headless.api.pojo.request.DomainReq;
import com.tencent.supersonic.headless.api.pojo.request.DomainUpdateReq;
import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
import com.tencent.supersonic.headless.server.persistence.dataobject.DomainDO;
import java.util.List;
import java.util.Map;
@@ -32,5 +33,7 @@ public interface DomainService {
Set<DomainResp> getDomainAuthSet(User user, AuthType authTypeEnum);
List<DomainDO> getDomainByBizName(String bizName);
Set<DomainResp> getDomainChildren(List<Long> domainId);
}

View File

@@ -23,6 +23,8 @@ public interface ModelService {
ModelResp createModel(ModelReq datasourceReq, User user) throws Exception;
List<ModelResp> createModel(ModelBuildReq modelBuildReq, User user) throws Exception;
ModelResp updateModel(ModelReq datasourceReq, User user) throws Exception;
List<ModelResp> getModelList(MetaFilter metaFilter);

View File

@@ -1,11 +1,13 @@
package com.tencent.supersonic.headless.server.service.impl;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.service.impl.ServiceImpl;
import com.google.common.collect.Lists;
import com.tencent.supersonic.common.pojo.QueryColumn;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.EngineType;
import com.tencent.supersonic.headless.api.pojo.DBColumn;
import com.tencent.supersonic.headless.api.pojo.enums.DataType;
import com.tencent.supersonic.headless.api.pojo.request.DatabaseReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
import com.tencent.supersonic.headless.api.pojo.request.SqlExecuteReq;
@@ -131,6 +133,15 @@ public class DatabaseServiceImpl extends ServiceImpl<DatabaseDOMapper, DatabaseD
return databaseResp;
}
@Override
public List<DatabaseResp> getDatabaseByType(DataType dataType) {
QueryWrapper<DatabaseDO> queryWrapper = new QueryWrapper<>();
queryWrapper.lambda().eq(DatabaseDO::getType, dataType.getFeature());
List<DatabaseDO> list = list(queryWrapper);
return list.stream().map(DatabaseConverter::convertWithPassword)
.collect(Collectors.toList());
}
@Override
public SemanticQueryResp executeSql(SqlExecuteReq sqlExecuteReq, Long id, User user) {
DatabaseResp databaseResp = getDatabase(id);

View File

@@ -179,6 +179,11 @@ public class DomainServiceImpl implements DomainService {
.collect(Collectors.toMap(DomainResp::getId, a -> a, (k1, k2) -> k1));
}
@Override
public List<DomainDO> getDomainByBizName(String bizName) {
return domainRepository.getDomainByBizName(bizName);
}
@Override
public Set<DomainResp> getDomainChildren(List<Long> domainIds) {
Set<DomainResp> childDomains = new HashSet<>();

View File

@@ -125,6 +125,19 @@ public class ModelServiceImpl implements ModelService {
return ModelConverter.convert(modelDO);
}
@Override
public List<ModelResp> createModel(ModelBuildReq modelBuildReq, User user) throws Exception {
List<ModelResp> modelResps = Lists.newArrayList();
Map<String, ModelSchema> modelSchemaMap = buildModelSchema(modelBuildReq);
for (Map.Entry<String, ModelSchema> entry : modelSchemaMap.entrySet()) {
ModelReq modelReq =
ModelConverter.convert(entry.getValue(), modelBuildReq, entry.getKey());
ModelResp modelResp = createModel(modelReq, user);
modelResps.add(modelResp);
}
return modelResps;
}
@Override
@Transactional
public ModelResp updateModel(ModelReq modelReq, User user) throws Exception {
@@ -231,6 +244,9 @@ public class ModelServiceImpl implements ModelService {
}
private List<DbSchema> getDbSchemes(ModelBuildReq modelBuildReq) throws SQLException {
if (!CollectionUtils.isEmpty(modelBuildReq.getDbSchemas())) {
return modelBuildReq.getDbSchemas();
}
Map<String, List<DBColumn>> dbColumnMap = databaseService.getDbColumns(modelBuildReq);
return convert(dbColumnMap, modelBuildReq);
}

View File

@@ -7,6 +7,7 @@ import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.StatusEnum;
import com.tencent.supersonic.common.util.BeanMapper;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.ColumnSchema;
import com.tencent.supersonic.headless.api.pojo.Dim;
import com.tencent.supersonic.headless.api.pojo.DrillDownDimension;
import com.tencent.supersonic.headless.api.pojo.Identify;
@@ -14,11 +15,16 @@ import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.pojo.MeasureParam;
import com.tencent.supersonic.headless.api.pojo.MetricDefineByMeasureParams;
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
import com.tencent.supersonic.headless.api.pojo.ModelSchema;
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
import com.tencent.supersonic.headless.api.pojo.enums.FieldType;
import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
import com.tencent.supersonic.headless.api.pojo.enums.MetricDefineType;
import com.tencent.supersonic.headless.api.pojo.enums.ModelDefineType;
import com.tencent.supersonic.headless.api.pojo.enums.SemanticType;
import com.tencent.supersonic.headless.api.pojo.request.DimensionReq;
import com.tencent.supersonic.headless.api.pojo.request.MetricReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelBuildReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
import com.tencent.supersonic.headless.api.pojo.response.MeasureResp;
@@ -156,6 +162,49 @@ public class ModelConverter {
return dimensionReq;
}
public static ModelReq convert(ModelSchema modelSchema, ModelBuildReq modelBuildReq,
String tableName) {
ModelReq modelReq = new ModelReq();
modelReq.setName(modelSchema.getName());
modelReq.setBizName(modelSchema.getBizName());
modelReq.setDatabaseId(modelBuildReq.getDatabaseId());
modelReq.setDomainId(modelBuildReq.getDomainId());
ModelDetail modelDetail = new ModelDetail();
if (StringUtils.isNotBlank(modelBuildReq.getSql())) {
modelDetail.setQueryType(ModelDefineType.SQL_QUERY.getName());
modelDetail.setSqlQuery(modelBuildReq.getSql());
} else {
modelDetail.setQueryType(ModelDefineType.TABLE_QUERY.getName());
modelDetail.setTableQuery(String.format("%s.%s", modelBuildReq.getDb(), tableName));
}
for (ColumnSchema columnSchema : modelSchema.getColumnSchemas()) {
FieldType fieldType = columnSchema.getFiledType();
if (getIdentifyType(fieldType) != null) {
Identify identify = new Identify(columnSchema.getName(),
getIdentifyType(fieldType).name(), columnSchema.getColumnName(), 1);
modelDetail.getIdentifiers().add(identify);
} else if (FieldType.measure.equals(fieldType)) {
Measure measure = new Measure(columnSchema.getName(), columnSchema.getColumnName(),
columnSchema.getAgg().getOperator(), 1);
modelDetail.getMeasures().add(measure);
} else {
Dim dim = new Dim(columnSchema.getName(), columnSchema.getColumnName(),
DimensionType.valueOf(columnSchema.getFiledType().name()), 1);
modelDetail.getDimensions().add(dim);
}
}
modelReq.setModelDetail(modelDetail);
return modelReq;
}
private static IdentifyType getIdentifyType(FieldType fieldType) {
if (FieldType.foreign_key.equals(fieldType) || FieldType.primary_key.equals(fieldType)) {
return IdentifyType.primary;
} else {
return IdentifyType.foreign;
}
}
public static List<ModelResp> convertList(List<ModelDO> modelDOS) {
List<ModelResp> modelDescs = Lists.newArrayList();
if (!CollectionUtils.isEmpty(modelDOS)) {

View File

@@ -40,21 +40,6 @@
<result column="query_opt_mode" property="queryOptMode"/>
</resultMap>
<insert id="createRecord">
insert into s2_query_stat_info
(
trace_id, model_id, data_set_id, `user`, query_type, query_type_back, query_sql_cmd, sql_cmd_md5, query_struct_cmd, struct_cmd_md5, `sql`, sql_md5, query_engine,
elapsed_ms, query_state, native_query, start_date, end_date, dimensions, metrics, select_cols, agg_cols, filter_cols, group_by_cols,
order_by_cols, use_result_cache, use_sql_cache, sql_cache_key, result_cache_key, query_opt_mode
)
values
(
#{traceId}, #{modelId}, #{dataSetId}, #{user}, #{queryType}, #{queryTypeBack}, #{querySqlCmd}, #{querySqlCmdMd5}, #{queryStructCmd}, #{queryStructCmdMd5}, #{sql}, #{sqlMd5}, #{queryEngine},
#{elapsedMs}, #{queryState}, #{nativeQuery}, #{startDate}, #{endDate}, #{dimensions}, #{metrics}, #{selectCols}, #{aggCols}, #{filterCols}, #{groupByCols},
#{orderByCols}, #{useResultCache}, #{useSqlCache}, #{sqlCacheKey}, #{resultCacheKey}, #{queryOptMode}
)
</insert>
<select id="getStatInfo"
resultType="com.tencent.supersonic.headless.api.pojo.QueryStat">
select *

View File

@@ -17,22 +17,6 @@
<result column="date_period" jdbcType="VARCHAR" property="datePeriod"/>
</resultMap>
<insert id="upsertDateInfo">
insert into s2_available_date_info
(`type`, item_id, date_format, start_date, end_date, unavailable_date, created_by,
updated_by,date_period)
values (#{type}, #{itemId}, #{dateFormat}, #{startDate}, #{endDate}, #{unavailableDateList},
#{createdBy}, #{updatedBy}, #{datePeriod}) ON DUPLICATE KEY
UPDATE
date_format = #{dateFormat},
start_date = #{startDate},
end_date = #{endDate},
unavailable_date = #{unavailableDateList},
created_by = #{createdBy},
updated_by = #{updatedBy},
date_period = #{datePeriod}
</insert>
<select id="getDateInfos" resultMap="BaseResultMap">
select e.*
from s2_available_date_info e