(improvement)(build) Add spotless during the build process. (#1639)

This commit is contained in:
lexluo09
2024-09-07 00:36:17 +08:00
committed by GitHub
parent ee15a88b06
commit 5f59e89eea
986 changed files with 15609 additions and 12706 deletions

View File

@@ -8,7 +8,8 @@ import org.springframework.scheduling.annotation.EnableAsync;
import org.springframework.scheduling.annotation.EnableScheduling;
import springfox.documentation.swagger2.annotations.EnableSwagger2;
@SpringBootApplication(scanBasePackages = {"com.tencent.supersonic", "dev.langchain4j"},
@SpringBootApplication(
scanBasePackages = {"com.tencent.supersonic", "dev.langchain4j"},
exclude = {MongoAutoConfiguration.class, MongoDataAutoConfiguration.class})
@EnableScheduling
@EnableAsync

View File

@@ -22,56 +22,39 @@ import springfox.documentation.swagger2.annotations.EnableSwagger2;
@EnableOpenApi
public class SwaggerConfiguration {
/**
* 标题
*/
/** 标题 */
@Value("${swagger.title}")
private String title;
/**
* 基本包
*/
/** 基本包 */
@Value("${swagger.base.package}")
private String basePackage;
/**
* 描述
*/
/** 描述 */
@Value("${swagger.description}")
private String description;
/**
* URL
*/
/** URL */
@Value("${swagger.url}")
private String url;
/**
* 作者
*/
/** 作者 */
@Value("${swagger.contact.name}")
private String contactName;
/**
* 作者网址
*/
/** 作者网址 */
@Value("${swagger.contact.url}")
private String contactUrl;
/**
* 作者邮箱
*/
/** 作者邮箱 */
@Value("${swagger.contact.email}")
private String contactEmail;
/**
* 版本
*/
/** 版本 */
@Value("${swagger.version}")
private String version;
@Autowired
private AuthenticationConfig authenticationConfig;
@Autowired private AuthenticationConfig authenticationConfig;
@Bean
public Docket createRestApi() {
@@ -81,13 +64,15 @@ public class SwaggerConfiguration {
.select()
.apis(RequestHandlerSelectors.basePackage(basePackage))
.paths(PathSelectors.any())
.build().securitySchemes(Lists.newArrayList(apiKey()));
.build()
.securitySchemes(Lists.newArrayList(apiKey()));
}
private ApiKey apiKey() {
return new ApiKey(authenticationConfig.getTokenHttpHeaderKey(),
authenticationConfig.getTokenHttpHeaderKey(), "header");
return new ApiKey(
authenticationConfig.getTokenHttpHeaderKey(),
authenticationConfig.getTokenHttpHeaderKey(),
"header");
}
private ApiInfo apiInfo() {
@@ -99,4 +84,4 @@ public class SwaggerConfiguration {
.version(version)
.build();
}
}
}

View File

@@ -1,5 +1,7 @@
package com.tencent.supersonic.db;
import javax.sql.DataSource;
import com.baomidou.mybatisplus.core.MybatisConfiguration;
import com.baomidou.mybatisplus.extension.spring.MybatisSqlSessionFactoryBean;
import org.apache.ibatis.annotations.Mapper;
@@ -9,9 +11,6 @@ import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import javax.sql.DataSource;
@Configuration
@MapperScan(value = "com.tencent.supersonic", annotationClass = Mapper.class)
public class MybatisConfig {
@@ -26,7 +25,8 @@ public class MybatisConfig {
bean.setConfiguration(configuration);
bean.setDataSource(dataSource);
bean.setMapperLocations(new PathMatchingResourcePatternResolver().getResources(MAPPER_LOCATION));
bean.setMapperLocations(
new PathMatchingResourcePatternResolver().getResources(MAPPER_LOCATION));
return bean.getObject();
}
}

View File

@@ -8,23 +8,23 @@ import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.headless.api.pojo.DefaultDisplayInfo;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
import com.tencent.supersonic.headless.api.pojo.Dim;
import com.tencent.supersonic.headless.api.pojo.DimensionTimeTypeParams;
import com.tencent.supersonic.headless.api.pojo.DataSetDetail;
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
import com.tencent.supersonic.headless.api.pojo.DefaultDisplayInfo;
import com.tencent.supersonic.headless.api.pojo.Dim;
import com.tencent.supersonic.headless.api.pojo.DimensionTimeTypeParams;
import com.tencent.supersonic.headless.api.pojo.Identify;
import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.pojo.MetricTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.MetricTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
import com.tencent.supersonic.headless.api.pojo.request.DataSetReq;
import com.tencent.supersonic.headless.api.pojo.request.DomainReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
import com.tencent.supersonic.headless.api.pojo.request.DataSetReq;
import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
@@ -53,7 +53,7 @@ public class CspiderDemo extends S2BaseDemo {
addModelRela_3(s2Domain, songModelResp, artistModelResp);
addModelRela_4(s2Domain, songModelResp, genreModelResp);
addModelRela_5(s2Domain, songModelResp, filesModelResp);
//batchPushlishMetric();
// batchPushlishMetric();
} catch (Exception e) {
log.error("Failed to add bench mark demo data", e);
}
@@ -149,7 +149,7 @@ public class CspiderDemo extends S2BaseDemo {
List<Dim> dimensions = new ArrayList<>();
dimensions.add(new Dim("持续时间", "duration", DimensionType.categorical.name(), 1));
dimensions.add(new Dim("文件格式", "formats", DimensionType.categorical.name(), 1));
//dimensions.add(new Dim("艺术家名称", "artist_name", DimensionType.categorical.name(), 1));
// dimensions.add(new Dim("艺术家名称", "artist_name", DimensionType.categorical.name(), 1));
modelDetail.setDimensions(dimensions);
List<Identify> identifiers = new ArrayList<>();
@@ -160,7 +160,8 @@ public class CspiderDemo extends S2BaseDemo {
modelDetail.setMeasures(Collections.emptyList());
modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("SELECT f_id, artist_name, file_size, duration, formats FROM files");
modelDetail.setSqlQuery(
"SELECT f_id, artist_name, file_size, duration, formats FROM files");
modelReq.setModelDetail(modelDetail);
return modelService.createModel(modelReq, user);
}
@@ -187,7 +188,7 @@ public class CspiderDemo extends S2BaseDemo {
identifiers.add(new Identify("歌曲名称", IdentifyType.primary.name(), "song_name"));
identifiers.add(new Identify("歌曲ID", IdentifyType.foreign.name(), "f_id"));
identifiers.add(new Identify("艺术家名称", IdentifyType.foreign.name(), "artist_name"));
//identifiers.add(new Identify("艺术家名称", IdentifyType.foreign.name(), "artist_name"));
// identifiers.add(new Identify("艺术家名称", IdentifyType.foreign.name(), "artist_name"));
modelDetail.setIdentifiers(identifiers);
@@ -197,8 +198,9 @@ public class CspiderDemo extends S2BaseDemo {
modelDetail.setMeasures(measures);
modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("SELECT imp_date, song_name, artist_name, country, f_id, g_name, "
+ " rating, languages, releasedate, resolution FROM song");
modelDetail.setSqlQuery(
"SELECT imp_date, song_name, artist_name, country, f_id, g_name, "
+ " rating, languages, releasedate, resolution FROM song");
modelReq.setModelDetail(modelDetail);
return modelService.createModel(modelReq, user);
}
@@ -236,7 +238,8 @@ public class CspiderDemo extends S2BaseDemo {
dataSetService.save(dataSetReq, User.getFakeUser());
}
public void addModelRela_1(DomainResp s2Domain, ModelResp genreModelResp, ModelResp artistModelResp) {
public void addModelRela_1(
DomainResp s2Domain, ModelResp genreModelResp, ModelResp artistModelResp) {
List<JoinCondition> joinConditions = Lists.newArrayList();
joinConditions.add(new JoinCondition("g_name", "g_name", FilterOperatorEnum.EQUALS));
ModelRela modelRelaReq = new ModelRela();
@@ -248,9 +251,11 @@ public class CspiderDemo extends S2BaseDemo {
modelRelaService.save(modelRelaReq, user);
}
public void addModelRela_2(DomainResp s2Domain, ModelResp filesModelResp, ModelResp artistModelResp) {
public void addModelRela_2(
DomainResp s2Domain, ModelResp filesModelResp, ModelResp artistModelResp) {
List<JoinCondition> joinConditions = Lists.newArrayList();
joinConditions.add(new JoinCondition("artist_name", "artist_name", FilterOperatorEnum.EQUALS));
joinConditions.add(
new JoinCondition("artist_name", "artist_name", FilterOperatorEnum.EQUALS));
ModelRela modelRelaReq = new ModelRela();
modelRelaReq.setDomainId(s2Domain.getId());
modelRelaReq.setFromModelId(filesModelResp.getId());
@@ -260,9 +265,11 @@ public class CspiderDemo extends S2BaseDemo {
modelRelaService.save(modelRelaReq, user);
}
public void addModelRela_3(DomainResp s2Domain, ModelResp songModelResp, ModelResp artistModelResp) {
public void addModelRela_3(
DomainResp s2Domain, ModelResp songModelResp, ModelResp artistModelResp) {
List<JoinCondition> joinConditions = Lists.newArrayList();
joinConditions.add(new JoinCondition("artist_name", "artist_name", FilterOperatorEnum.EQUALS));
joinConditions.add(
new JoinCondition("artist_name", "artist_name", FilterOperatorEnum.EQUALS));
ModelRela modelRelaReq = new ModelRela();
modelRelaReq.setDomainId(s2Domain.getId());
modelRelaReq.setFromModelId(songModelResp.getId());
@@ -272,7 +279,8 @@ public class CspiderDemo extends S2BaseDemo {
modelRelaService.save(modelRelaReq, user);
}
public void addModelRela_4(DomainResp s2Domain, ModelResp songModelResp, ModelResp genreModelResp) {
public void addModelRela_4(
DomainResp s2Domain, ModelResp songModelResp, ModelResp genreModelResp) {
List<JoinCondition> joinConditions = Lists.newArrayList();
joinConditions.add(new JoinCondition("g_name", "g_name", FilterOperatorEnum.EQUALS));
ModelRela modelRelaReq = new ModelRela();
@@ -284,7 +292,8 @@ public class CspiderDemo extends S2BaseDemo {
modelRelaService.save(modelRelaReq, user);
}
public void addModelRela_5(DomainResp s2Domain, ModelResp songModelResp, ModelResp filesModelResp) {
public void addModelRela_5(
DomainResp s2Domain, ModelResp songModelResp, ModelResp filesModelResp) {
List<JoinCondition> joinConditions = Lists.newArrayList();
joinConditions.add(new JoinCondition("f_id", "f_id", FilterOperatorEnum.EQUALS));
ModelRela modelRelaReq = new ModelRela();
@@ -300,5 +309,4 @@ public class CspiderDemo extends S2BaseDemo {
List<Long> ids = Lists.newArrayList(1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L);
metricService.batchPublish(ids, User.getFakeUser());
}
}

View File

@@ -10,26 +10,26 @@ import com.tencent.supersonic.chat.server.agent.LLMParserTool;
import com.tencent.supersonic.common.pojo.JoinCondition;
import com.tencent.supersonic.common.pojo.ModelRela;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.util.JsonUtil;
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
import com.tencent.supersonic.headless.api.pojo.DataSetDetail;
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
import com.tencent.supersonic.headless.api.pojo.Dim;
import com.tencent.supersonic.headless.api.pojo.DimensionTimeTypeParams;
import com.tencent.supersonic.headless.api.pojo.Identify;
import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
import com.tencent.supersonic.headless.api.pojo.DataSetDetail;
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.MetricTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
import com.tencent.supersonic.headless.api.pojo.request.DataSetReq;
import com.tencent.supersonic.headless.api.pojo.request.DomainReq;
import com.tencent.supersonic.headless.api.pojo.request.MetricReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
import com.tencent.supersonic.headless.api.pojo.request.DataSetReq;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
@@ -60,7 +60,6 @@ public class DuSQLDemo extends S2BaseDemo {
} catch (Exception e) {
log.error("Failed to add bench mark demo data", e);
}
}
@Override
@@ -80,7 +79,7 @@ public class DuSQLDemo extends S2BaseDemo {
domainService.createDomain(domainReq, user);
}
//9
// 9
public void addModel_1() throws Exception {
ModelReq modelReq = new ModelReq();
modelReq.setName("公司");
@@ -95,12 +94,14 @@ public class DuSQLDemo extends S2BaseDemo {
ModelDetail modelDetail = new ModelDetail();
List<Dim> dimensions = new ArrayList<>();
Dim dimension1 = new Dim("", "imp_date", DimensionType.partition_time.name(), 0);
DimensionTimeTypeParams dimensionTimeTypeParams = new DimensionTimeTypeParams("false", "none");
DimensionTimeTypeParams dimensionTimeTypeParams =
new DimensionTimeTypeParams("false", "none");
dimension1.setTypeParams(dimensionTimeTypeParams);
dimensions.add(dimension1);
dimensions.add(new Dim("公司名称", "company_name", DimensionType.categorical.name(), 1));
dimensions.add(new Dim("总部地点", "headquarter_address", DimensionType.categorical.name(), 1));
dimensions.add(new Dim("公司成立时间", "company_established_time", DimensionType.categorical.name(), 1));
dimensions.add(
new Dim("公司成立时间", "company_established_time", DimensionType.categorical.name(), 1));
dimensions.add(new Dim("创始人", "founder", DimensionType.categorical.name(), 1));
dimensions.add(new Dim("首席执行官", "ceo", DimensionType.categorical.name(), 1));
modelDetail.setDimensions(dimensions);
@@ -116,8 +117,9 @@ public class DuSQLDemo extends S2BaseDemo {
modelDetail.setMeasures(measures);
modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("SELECT imp_date,company_id,company_name,headquarter_address,"
+ "company_established_time,founder,ceo,annual_turnover,employee_count FROM company");
modelDetail.setSqlQuery(
"SELECT imp_date,company_id,company_name,headquarter_address,"
+ "company_established_time,founder,ceo,annual_turnover,employee_count FROM company");
modelReq.setModelDetail(modelDetail);
modelService.createModel(modelReq, user);
}
@@ -137,12 +139,15 @@ public class DuSQLDemo extends S2BaseDemo {
ModelDetail modelDetail = new ModelDetail();
List<Dim> dimensions = new ArrayList<>();
Dim dimension1 = new Dim("", "imp_date", DimensionType.partition_time.name(), 0);
DimensionTimeTypeParams dimensionTimeTypeParams = new DimensionTimeTypeParams("false", "none");
DimensionTimeTypeParams dimensionTimeTypeParams =
new DimensionTimeTypeParams("false", "none");
dimension1.setTypeParams(dimensionTimeTypeParams);
dimensions.add(dimension1);
dimensions.add(new Dim("品牌名称", "brand_name", DimensionType.categorical.name(), 1));
dimensions.add(new Dim("品牌成立时间", "brand_established_time", DimensionType.categorical.name(), 1));
dimensions.add(new Dim("法定代表人", "legal_representative", DimensionType.categorical.name(), 1));
dimensions.add(
new Dim("品牌成立时间", "brand_established_time", DimensionType.categorical.name(), 1));
dimensions.add(
new Dim("法定代表人", "legal_representative", DimensionType.categorical.name(), 1));
modelDetail.setDimensions(dimensions);
List<Identify> identifiers = new ArrayList<>();
@@ -155,8 +160,9 @@ public class DuSQLDemo extends S2BaseDemo {
modelDetail.setMeasures(measures);
modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("SELECT imp_date,brand_id,brand_name,brand_established_time,"
+ "company_id,legal_representative,registered_capital FROM brand");
modelDetail.setSqlQuery(
"SELECT imp_date,brand_id,brand_name,brand_established_time,"
+ "company_id,legal_representative,registered_capital FROM brand");
modelReq.setModelDetail(modelDetail);
modelService.createModel(modelReq, user);
}
@@ -176,7 +182,8 @@ public class DuSQLDemo extends S2BaseDemo {
ModelDetail modelDetail = new ModelDetail();
List<Dim> dimensions = new ArrayList<>();
Dim dimension1 = new Dim("", "imp_date", DimensionType.partition_time.name(), 0);
DimensionTimeTypeParams dimensionTimeTypeParams = new DimensionTimeTypeParams("false", "none");
DimensionTimeTypeParams dimensionTimeTypeParams =
new DimensionTimeTypeParams("false", "none");
dimension1.setTypeParams(dimensionTimeTypeParams);
dimensions.add(dimension1);
modelDetail.setDimensions(dimensions);
@@ -194,8 +201,9 @@ public class DuSQLDemo extends S2BaseDemo {
modelDetail.setMeasures(measures);
modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("SELECT imp_date,company_id,brand_id,revenue_proportion,"
+ "profit_proportion,expenditure_proportion FROM company_revenue");
modelDetail.setSqlQuery(
"SELECT imp_date,company_id,brand_id,revenue_proportion,"
+ "profit_proportion,expenditure_proportion FROM company_revenue");
modelReq.setModelDetail(modelDetail);
modelService.createModel(modelReq, user);
MetricResp metricResp = metricService.getMetric(13L, user);
@@ -221,7 +229,8 @@ public class DuSQLDemo extends S2BaseDemo {
ModelDetail modelDetail = new ModelDetail();
List<Dim> dimensions = new ArrayList<>();
Dim dimension1 = new Dim("", "imp_date", DimensionType.partition_time.name(), 0);
DimensionTimeTypeParams dimensionTimeTypeParams = new DimensionTimeTypeParams("false", "none");
DimensionTimeTypeParams dimensionTimeTypeParams =
new DimensionTimeTypeParams("false", "none");
dimension1.setTypeParams(dimensionTimeTypeParams);
dimensions.add(dimension1);
dimensions.add(new Dim("年份", "year_time", DimensionType.categorical.name(), 1));
@@ -234,16 +243,19 @@ public class DuSQLDemo extends S2BaseDemo {
List<Measure> measures = new ArrayList<>();
measures.add(new Measure("营收", "revenue", AggOperatorEnum.SUM.name(), 1));
measures.add(new Measure("利润", "profit", AggOperatorEnum.SUM.name(), 1));
measures.add(new Measure("营收同比增长", "revenue_growth_year_on_year", AggOperatorEnum.SUM.name(), 1));
measures.add(new Measure("利润同比增长", "profit_growth_year_on_year", AggOperatorEnum.SUM.name(), 1));
measures.add(
new Measure(
"营收同比增长", "revenue_growth_year_on_year", AggOperatorEnum.SUM.name(), 1));
measures.add(
new Measure("利润同比增长", "profit_growth_year_on_year", AggOperatorEnum.SUM.name(), 1));
modelDetail.setMeasures(measures);
modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("SELECT imp_date,year_time,brand_id,revenue,profit,"
+ "revenue_growth_year_on_year,profit_growth_year_on_year FROM company_brand_revenue");
modelDetail.setSqlQuery(
"SELECT imp_date,year_time,brand_id,revenue,profit,"
+ "revenue_growth_year_on_year,profit_growth_year_on_year FROM company_brand_revenue");
modelReq.setModelDetail(modelDetail);
modelService.createModel(modelReq, user);
}
public void addDataSet_1() {
@@ -253,11 +265,20 @@ public class DuSQLDemo extends S2BaseDemo {
dataSetReq.setDomainId(4L);
dataSetReq.setDescription("DuSQL互联网企业数据源相关的指标和维度等");
dataSetReq.setAdmins(Lists.newArrayList("admin"));
List<DataSetModelConfig> viewModelConfigs = Lists.newArrayList(
new DataSetModelConfig(9L, Lists.newArrayList(16L, 17L, 18L, 19L, 20L), Lists.newArrayList(10L, 11L)),
new DataSetModelConfig(10L, Lists.newArrayList(21L, 22L, 23L), Lists.newArrayList(12L)),
new DataSetModelConfig(11L, Lists.newArrayList(), Lists.newArrayList(13L, 14L, 15L)),
new DataSetModelConfig(12L, Lists.newArrayList(24L), Lists.newArrayList(16L, 17L, 18L, 19L)));
List<DataSetModelConfig> viewModelConfigs =
Lists.newArrayList(
new DataSetModelConfig(
9L,
Lists.newArrayList(16L, 17L, 18L, 19L, 20L),
Lists.newArrayList(10L, 11L)),
new DataSetModelConfig(
10L, Lists.newArrayList(21L, 22L, 23L), Lists.newArrayList(12L)),
new DataSetModelConfig(
11L, Lists.newArrayList(), Lists.newArrayList(13L, 14L, 15L)),
new DataSetModelConfig(
12L,
Lists.newArrayList(24L),
Lists.newArrayList(16L, 17L, 18L, 19L)));
DataSetDetail dsDetail = new DataSetDetail();
dsDetail.setDataSetModelConfigs(viewModelConfigs);
@@ -276,7 +297,8 @@ public class DuSQLDemo extends S2BaseDemo {
public void addModelRela_1() {
List<JoinCondition> joinConditions = Lists.newArrayList();
joinConditions.add(new JoinCondition("company_id", "company_id", FilterOperatorEnum.EQUALS));
joinConditions.add(
new JoinCondition("company_id", "company_id", FilterOperatorEnum.EQUALS));
ModelRela modelRelaReq = new ModelRela();
modelRelaReq.setDomainId(4L);
modelRelaReq.setFromModelId(9L);
@@ -288,7 +310,8 @@ public class DuSQLDemo extends S2BaseDemo {
public void addModelRela_2() {
List<JoinCondition> joinConditions = Lists.newArrayList();
joinConditions.add(new JoinCondition("company_id", "company_id", FilterOperatorEnum.EQUALS));
joinConditions.add(
new JoinCondition("company_id", "company_id", FilterOperatorEnum.EQUALS));
ModelRela modelRelaReq = new ModelRela();
modelRelaReq.setDomainId(4L);
modelRelaReq.setFromModelId(9L);
@@ -341,5 +364,4 @@ public class DuSQLDemo extends S2BaseDemo {
log.info("agent:{}", JsonUtil.toString(agent));
agentService.createAgent(agent, User.getFakeUser());
}
}

View File

@@ -97,8 +97,9 @@ public class S2ArtistDemo extends S2BaseDemo {
return domainService.createDomain(domainReq, user);
}
public ModelResp addModel(DomainResp singerDomain,
DatabaseResp s2Database, TagObjectResp singerTagObject) throws Exception {
public ModelResp addModel(
DomainResp singerDomain, DatabaseResp s2Database, TagObjectResp singerTagObject)
throws Exception {
ModelReq modelReq = new ModelReq();
modelReq.setName("歌手库");
modelReq.setBizName("singer");
@@ -118,12 +119,9 @@ public class S2ArtistDemo extends S2BaseDemo {
modelDetail.setIdentifiers(identifiers);
List<Dim> dimensions = new ArrayList<>();
dimensions.add(new Dim("活跃区域", "act_area",
DimensionType.categorical.name(), 1, 1));
dimensions.add(new Dim("代表作", "song_name",
DimensionType.categorical.name(), 1));
dimensions.add(new Dim("流派", "genre",
DimensionType.categorical.name(), 1, 1));
dimensions.add(new Dim("活跃区域", "act_area", DimensionType.categorical.name(), 1, 1));
dimensions.add(new Dim("代表作", "song_name", DimensionType.categorical.name(), 1));
dimensions.add(new Dim("流派", "genre", DimensionType.categorical.name(), 1, 1));
modelDetail.setDimensions(dimensions);
Measure measure1 = new Measure("播放量", "js_play_cnt", "sum", 1);
@@ -131,23 +129,27 @@ public class S2ArtistDemo extends S2BaseDemo {
Measure measure3 = new Measure("收藏量", "favor_cnt", "sum", 1);
modelDetail.setMeasures(Lists.newArrayList(measure1, measure2, measure3));
modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("select singer_name, act_area, song_name, genre, "
+ "js_play_cnt, down_cnt, favor_cnt from singer");
modelDetail.setSqlQuery(
"select singer_name, act_area, song_name, genre, "
+ "js_play_cnt, down_cnt, favor_cnt from singer");
modelReq.setModelDetail(modelDetail);
return modelService.createModel(modelReq, user);
}
private void addTags(ModelResp model) {
addTag(dimensionService.getDimension("act_area", model.getId()).getId(),
addTag(
dimensionService.getDimension("act_area", model.getId()).getId(),
TagDefineType.DIMENSION);
addTag(dimensionService.getDimension("song_name", model.getId()).getId(),
addTag(
dimensionService.getDimension("song_name", model.getId()).getId(),
TagDefineType.DIMENSION);
addTag(dimensionService.getDimension("genre", model.getId()).getId(),
addTag(
dimensionService.getDimension("genre", model.getId()).getId(),
TagDefineType.DIMENSION);
addTag(dimensionService.getDimension("singer_name", model.getId()).getId(),
addTag(
dimensionService.getDimension("singer_name", model.getId()).getId(),
TagDefineType.DIMENSION);
addTag(metricService.getMetric(model.getId(), "js_play_cnt").getId(),
TagDefineType.METRIC);
addTag(metricService.getMetric(model.getId(), "js_play_cnt").getId(), TagDefineType.METRIC);
}
public long addDataSet(DomainResp singerDomain, ModelResp singerModel) {
@@ -209,5 +211,4 @@ public class S2ArtistDemo extends S2BaseDemo {
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
agentService.createAgent(agent, User.getFakeUser());
}
}

View File

@@ -11,6 +11,7 @@ import com.tencent.supersonic.common.service.SystemConfigService;
import com.tencent.supersonic.common.util.AESEncryptionUtil;
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
import com.tencent.supersonic.headless.api.pojo.DrillDownDimension;
import com.tencent.supersonic.headless.api.pojo.MetaFilter;
import com.tencent.supersonic.headless.api.pojo.RelateDimension;
import com.tencent.supersonic.headless.api.pojo.enums.DataType;
import com.tencent.supersonic.headless.api.pojo.enums.TagDefineType;
@@ -20,7 +21,6 @@ import com.tencent.supersonic.headless.api.pojo.response.DatabaseResp;
import com.tencent.supersonic.headless.api.pojo.response.DimensionResp;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.MetaFilter;
import com.tencent.supersonic.headless.server.service.CanvasService;
import com.tencent.supersonic.headless.server.service.DataSetService;
import com.tencent.supersonic.headless.server.service.DatabaseService;
@@ -49,46 +49,29 @@ public abstract class S2BaseDemo implements CommandLineRunner {
protected DatabaseResp demoDatabaseResp;
protected User user = User.getFakeUser();
@Autowired
protected DatabaseService databaseService;
@Autowired
protected DomainService domainService;
@Autowired
protected ModelService modelService;
@Autowired
protected ModelRelaService modelRelaService;
@Autowired
protected DimensionService dimensionService;
@Autowired
protected MetricService metricService;
@Autowired
protected TagMetaService tagMetaService;
@Autowired
protected AuthService authService;
@Autowired
protected DataSetService dataSetService;
@Autowired
protected TermService termService;
@Autowired
protected PluginService pluginService;
@Autowired
protected DataSourceProperties dataSourceProperties;
@Autowired
protected TagObjectService tagObjectService;
@Autowired
protected ChatQueryService chatQueryService;
@Autowired
protected ChatManageService chatManageService;
@Autowired
protected AgentService agentService;
@Autowired
protected SystemConfigService sysParameterService;
@Autowired
protected CanvasService canvasService;
@Autowired
protected DictWordService dictWordService;
@Autowired protected DatabaseService databaseService;
@Autowired protected DomainService domainService;
@Autowired protected ModelService modelService;
@Autowired protected ModelRelaService modelRelaService;
@Autowired protected DimensionService dimensionService;
@Autowired protected MetricService metricService;
@Autowired protected TagMetaService tagMetaService;
@Autowired protected AuthService authService;
@Autowired protected DataSetService dataSetService;
@Autowired protected TermService termService;
@Autowired protected PluginService pluginService;
@Autowired protected DataSourceProperties dataSourceProperties;
@Autowired protected TagObjectService tagObjectService;
@Autowired protected ChatQueryService chatQueryService;
@Autowired protected ChatManageService chatManageService;
@Autowired protected AgentService agentService;
@Autowired protected SystemConfigService sysParameterService;
@Autowired protected CanvasService canvasService;
@Autowired protected DictWordService dictWordService;
@Value("${s2.demo.names:S2VisitsDemo}")
protected List<String> demoList;
@Value("${s2.demo.enableLLM:true}")
protected boolean demoEnableLlm;
@@ -123,7 +106,8 @@ public abstract class S2BaseDemo implements CommandLineRunner {
}
databaseReq.setUrl(url);
databaseReq.setUsername(dataSourceProperties.getUsername());
databaseReq.setPassword(AESEncryptionUtil.aesEncryptECB(dataSourceProperties.getPassword()));
databaseReq.setPassword(
AESEncryptionUtil.aesEncryptECB(dataSourceProperties.getPassword()));
return databaseService.createOrUpdateDatabase(databaseReq, user);
}
@@ -141,11 +125,15 @@ public abstract class S2BaseDemo implements CommandLineRunner {
dataSetModelConfig.setId(modelResp.getId());
MetaFilter metaFilter = new MetaFilter();
metaFilter.setModelIds(Lists.newArrayList(modelResp.getId()));
List<Long> metrics = metricService.getMetrics(metaFilter)
.stream().map(MetricResp::getId).collect(Collectors.toList());
List<Long> metrics =
metricService.getMetrics(metaFilter).stream()
.map(MetricResp::getId)
.collect(Collectors.toList());
dataSetModelConfig.setMetrics(metrics);
List<Long> dimensions = dimensionService.getDimensions(metaFilter)
.stream().map(DimensionResp::getId).collect(Collectors.toList());
List<Long> dimensions =
dimensionService.getDimensions(metaFilter).stream()
.map(DimensionResp::getId)
.collect(Collectors.toList());
dataSetModelConfig.setMetrics(metrics);
dataSetModelConfig.setDimensions(dimensions);
dataSetModelConfigs.add(dataSetModelConfig);
@@ -175,5 +163,4 @@ public abstract class S2BaseDemo implements CommandLineRunner {
protected void updateQueryScore(Integer queryId) {
chatManageService.updateFeedback(queryId, 5, "");
}
}

View File

@@ -86,7 +86,7 @@ public class S2VisitsDemo extends S2BaseDemo {
addModelRela_2(s2Domain, userModel, stayTimeModel);
addTags(userModel);
//create metrics and dimensions
// create metrics and dimensions
DimensionResp departmentDimension = getDimension("department", userModel);
MetricResp metricUv = addMetric_uv(pvUvModel, departmentDimension);
MetricResp metricPv = getMetric("pv", pvUvModel);
@@ -98,21 +98,21 @@ public class S2VisitsDemo extends S2BaseDemo {
updateMetric(stayTimeModel, departmentDimension, userDimension);
updateMetric_pv(pvUvModel, departmentDimension, userDimension, metricPv);
//create data set
// create data set
DataSetResp s2DataSet = addDataSet(s2Domain);
addAuthGroup_1(stayTimeModel);
addAuthGroup_2(stayTimeModel);
//create terms and plugin
// create terms and plugin
addTerm(s2Domain);
addTerm_1(s2Domain);
addPlugin(s2DataSet);
addPlugin_1();
//load dict word
// load dict word
loadDictWord();
//create agent
// create agent
Integer agentId = addAgent(s2DataSet.getId());
addSampleChats(agentId);
updateQueryScore(1);
@@ -150,12 +150,13 @@ public class S2VisitsDemo extends S2BaseDemo {
agent.setDescription("帮助您用自然语言查询指标,支持时间限定、条件筛选、下钻维度以及聚合统计");
agent.setStatus(1);
agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList(
"超音数访问次数",
"近15天超音数访问次数汇总",
"按部门统计超音数访问人数",
"对比alice和lucy的停留时长",
"超音数访问次数最高的部门"));
agent.setExamples(
Lists.newArrayList(
"超音数访问次数",
"近15天超音数访问次数汇总",
"按部门统计超音数的访问人数",
"对比alice和lucy的停留时长",
"超音数访问次数最高的部门"));
AgentConfig agentConfig = new AgentConfig();
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
@@ -188,8 +189,9 @@ public class S2VisitsDemo extends S2BaseDemo {
return domainService.createDomain(domainReq, user);
}
public ModelResp addModel_1(DomainResp s2Domain, DatabaseResp s2Database,
TagObjectResp s2TagObject) throws Exception {
public ModelResp addModel_1(
DomainResp s2Domain, DatabaseResp s2Database, TagObjectResp s2TagObject)
throws Exception {
ModelReq modelReq = new ModelReq();
modelReq.setName("用户部门");
modelReq.setBizName("user_department");
@@ -207,8 +209,7 @@ public class S2VisitsDemo extends S2BaseDemo {
modelDetail.setIdentifiers(identifiers);
List<Dim> dimensions = new ArrayList<>();
dimensions.add(new Dim("部门", "department",
DimensionType.categorical.name(), 1));
dimensions.add(new Dim("部门", "department", DimensionType.categorical.name(), 1));
modelDetail.setDimensions(dimensions);
List<Field> fields = Lists.newArrayList();
fields.add(Field.builder().fieldName("user_name").dataType("Varchar").build());
@@ -258,8 +259,9 @@ public class S2VisitsDemo extends S2BaseDemo {
fields.add(Field.builder().fieldName("pv").dataType("Long").build());
fields.add(Field.builder().fieldName("user_id").dataType("Varchar").build());
modelDetail.setFields(fields);
modelDetail.setSqlQuery("SELECT imp_date, user_name, page, 1 as pv, "
+ "user_name as user_id FROM s2_pv_uv_statis");
modelDetail.setSqlQuery(
"SELECT imp_date, user_name, page, 1 as pv, "
+ "user_name as user_id FROM s2_pv_uv_statis");
modelDetail.setQueryType("sql_query");
modelReq.setModelDetail(modelDetail);
return modelService.createModel(modelReq, user);
@@ -300,13 +302,15 @@ public class S2VisitsDemo extends S2BaseDemo {
fields.add(Field.builder().fieldName("page").dataType("Varchar").build());
fields.add(Field.builder().fieldName("stay_hours").dataType("Double").build());
modelDetail.setFields(fields);
modelDetail.setSqlQuery("select imp_date,user_name,stay_hours,page from s2_stay_time_statis");
modelDetail.setSqlQuery(
"select imp_date,user_name,stay_hours,page from s2_stay_time_statis");
modelDetail.setQueryType("sql_query");
modelReq.setModelDetail(modelDetail);
return modelService.createModel(modelReq, user);
}
public void addModelRela_1(DomainResp s2Domain, ModelResp userDepartmentModel, ModelResp pvUvModel) {
public void addModelRela_1(
DomainResp s2Domain, ModelResp userDepartmentModel, ModelResp pvUvModel) {
List<JoinCondition> joinConditions = Lists.newArrayList();
joinConditions.add(new JoinCondition("user_name", "user_name", FilterOperatorEnum.EQUALS));
ModelRela modelRelaReq = new ModelRela();
@@ -318,7 +322,8 @@ public class S2VisitsDemo extends S2BaseDemo {
modelRelaService.save(modelRelaReq, user);
}
public void addModelRela_2(DomainResp s2Domain, ModelResp userDepartmentModel, ModelResp stayTimeModel) {
public void addModelRela_2(
DomainResp s2Domain, ModelResp userDepartmentModel, ModelResp stayTimeModel) {
List<JoinCondition> joinConditions = Lists.newArrayList();
joinConditions.add(new JoinCondition("user_name", "user_name", FilterOperatorEnum.EQUALS));
ModelRela modelRelaReq = new ModelRela();
@@ -331,11 +336,13 @@ public class S2VisitsDemo extends S2BaseDemo {
}
private void addTags(ModelResp model) {
addTag(dimensionService.getDimension("department", model.getId()).getId(),
addTag(
dimensionService.getDimension("department", model.getId()).getId(),
TagDefineType.DIMENSION);
}
public void updateDimension(ModelResp stayTimeModel, DimensionResp pageDimension) throws Exception {
public void updateDimension(ModelResp stayTimeModel, DimensionResp pageDimension)
throws Exception {
DimensionReq dimensionReq = new DimensionReq();
dimensionReq.setType(DimensionType.categorical.name());
dimensionReq.setId(pageDimension.getId());
@@ -351,10 +358,10 @@ public class S2VisitsDemo extends S2BaseDemo {
dimensionService.updateDimension(dimensionReq, user);
}
public void updateMetric(ModelResp stayTimeModel, DimensionResp departmentDimension,
DimensionResp userDimension) throws Exception {
MetricResp stayHoursMetric =
metricService.getMetric(stayTimeModel.getId(), "stay_hours");
public void updateMetric(
ModelResp stayTimeModel, DimensionResp departmentDimension, DimensionResp userDimension)
throws Exception {
MetricResp stayHoursMetric = metricService.getMetric(stayTimeModel.getId(), "stay_hours");
MetricReq metricReq = new MetricReq();
metricReq.setModelId(stayTimeModel.getId());
metricReq.setId(stayHoursMetric.getId());
@@ -366,19 +373,25 @@ public class S2VisitsDemo extends S2BaseDemo {
MetricDefineByMeasureParams metricTypeParams = new MetricDefineByMeasureParams();
metricTypeParams.setExpr("s2_stay_time_statis_stay_hours");
List<MeasureParam> measures = new ArrayList<>();
MeasureParam measure = new MeasureParam("s2_stay_time_statis_stay_hours",
"", AggOperatorEnum.SUM.getOperator());
MeasureParam measure =
new MeasureParam(
"s2_stay_time_statis_stay_hours", "", AggOperatorEnum.SUM.getOperator());
measures.add(measure);
metricTypeParams.setMeasures(measures);
metricReq.setMetricDefineByMeasureParams(metricTypeParams);
metricReq.setMetricDefineType(MetricDefineType.MEASURE);
metricReq.setRelateDimension(getRelateDimension(
Lists.newArrayList(departmentDimension.getId(), userDimension.getId())));
metricReq.setRelateDimension(
getRelateDimension(
Lists.newArrayList(departmentDimension.getId(), userDimension.getId())));
metricService.updateMetric(metricReq, user);
}
public void updateMetric_pv(ModelResp pvUvModel, DimensionResp departmentDimension,
DimensionResp userDimension, MetricResp metricPv) throws Exception {
public void updateMetric_pv(
ModelResp pvUvModel,
DimensionResp departmentDimension,
DimensionResp userDimension,
MetricResp metricPv)
throws Exception {
MetricReq metricReq = new MetricReq();
metricReq.setModelId(pvUvModel.getId());
metricReq.setId(metricPv.getId());
@@ -388,18 +401,20 @@ public class S2VisitsDemo extends S2BaseDemo {
MetricDefineByMeasureParams metricTypeParams = new MetricDefineByMeasureParams();
metricTypeParams.setExpr("s2_pv_uv_statis_pv");
List<MeasureParam> measures = new ArrayList<>();
MeasureParam measure = new MeasureParam("s2_pv_uv_statis_pv",
"", AggOperatorEnum.SUM.getOperator());
MeasureParam measure =
new MeasureParam("s2_pv_uv_statis_pv", "", AggOperatorEnum.SUM.getOperator());
measures.add(measure);
metricTypeParams.setMeasures(measures);
metricReq.setMetricDefineByMeasureParams(metricTypeParams);
metricReq.setMetricDefineType(MetricDefineType.MEASURE);
metricReq.setRelateDimension(getRelateDimension(
Lists.newArrayList(departmentDimension.getId(), userDimension.getId())));
metricReq.setRelateDimension(
getRelateDimension(
Lists.newArrayList(departmentDimension.getId(), userDimension.getId())));
metricService.updateMetric(metricReq, user);
}
public MetricResp addMetric_uv(ModelResp uvModel, DimensionResp departmentDimension) throws Exception {
public MetricResp addMetric_uv(ModelResp uvModel, DimensionResp departmentDimension)
throws Exception {
MetricReq metricReq = new MetricReq();
metricReq.setModelId(uvModel.getId());
metricReq.setName("访问用户数");
@@ -414,13 +429,17 @@ public class S2VisitsDemo extends S2BaseDemo {
metricTypeParams.setFields(fieldParams);
metricReq.setMetricDefineByFieldParams(metricTypeParams);
metricReq.setMetricDefineType(MetricDefineType.FIELD);
metricReq.setRelateDimension(getRelateDimension(
Lists.newArrayList(departmentDimension.getId())));
metricReq.setRelateDimension(
getRelateDimension(Lists.newArrayList(departmentDimension.getId())));
return metricService.createMetric(metricReq, user);
}
public MetricResp addMetric_pv_avg(MetricResp metricPv, MetricResp metricUv,
DimensionResp departmentDimension, ModelResp pvModel) throws Exception {
public MetricResp addMetric_pv_avg(
MetricResp metricPv,
MetricResp metricUv,
DimensionResp departmentDimension,
ModelResp pvModel)
throws Exception {
MetricReq metricReq = new MetricReq();
metricReq.setModelId(pvModel.getId());
metricReq.setName("人均访问次数");
@@ -439,7 +458,8 @@ public class S2VisitsDemo extends S2BaseDemo {
metricTypeParams.setMetrics(metrics);
metricReq.setMetricDefineByMetricParams(metricTypeParams);
metricReq.setMetricDefineType(MetricDefineType.METRIC);
metricReq.setRelateDimension(getRelateDimension(Lists.newArrayList(departmentDimension.getId())));
metricReq.setRelateDimension(
getRelateDimension(Lists.newArrayList(departmentDimension.getId())));
return metricService.createMetric(metricReq, user);
}
@@ -555,5 +575,4 @@ public class S2VisitsDemo extends S2BaseDemo {
private void loadDictWord() {
dictWordService.loadDictWord();
}
}

View File

@@ -5,7 +5,6 @@ import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentConfig;
import com.tencent.supersonic.chat.server.agent.MultiTurnConfig;
import lombok.extern.slf4j.Slf4j;
import org.springframework.core.annotation.Order;
@@ -27,8 +26,7 @@ public class SmallTalkDemo extends S2BaseDemo {
agent.setEnableSearch(0);
AgentConfig agentConfig = new AgentConfig();
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
agent.setExamples(Lists.newArrayList("如何才能变帅",
"如何才能赚更多钱", "如何才能世界和平"));
agent.setExamples(Lists.newArrayList("如何才能变帅", "如何才能赚更多钱", "如何才能世界和平"));
MultiTurnConfig multiTurnConfig = new MultiTurnConfig();
multiTurnConfig.setEnableMultiTurn(true);
agent.setMultiTurnConfig(multiTurnConfig);
@@ -38,9 +36,8 @@ public class SmallTalkDemo extends S2BaseDemo {
@Override
boolean checkNeedToRun() {
List<String> agentNames = agentService.getAgents()
.stream().map(Agent::getName).collect(Collectors.toList());
List<String> agentNames =
agentService.getAgents().stream().map(Agent::getName).collect(Collectors.toList());
return !agentNames.contains("来闲聊");
}
}

View File

@@ -3,6 +3,4 @@ package com.tencent.supersonic;
import org.springframework.boot.test.context.SpringBootTest;
@SpringBootTest(classes = {StandaloneLauncher.class})
public class BaseApplication {
}
public class BaseApplication {}

View File

@@ -3,12 +3,12 @@ package com.tencent.supersonic.chat;
import com.tencent.supersonic.BaseApplication;
import com.tencent.supersonic.chat.api.pojo.request.ChatExecuteReq;
import com.tencent.supersonic.chat.api.pojo.request.ChatParseReq;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.service.AgentService;
import com.tencent.supersonic.chat.server.service.ChatQueryService;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.util.DataUtils;
import org.springframework.beans.factory.annotation.Autowired;
@@ -26,23 +26,23 @@ public class BaseTest extends BaseApplication {
protected final String endDay = LocalDate.now().plusDays(-1).toString();
protected final String period = "DAY";
@Autowired
protected ChatQueryService chatQueryService;
@Autowired
protected AgentService agentService;
@Autowired protected ChatQueryService chatQueryService;
@Autowired protected AgentService agentService;
protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId) throws Exception {
protected QueryResult submitMultiTurnChat(String queryText, Integer agentId, Integer chatId)
throws Exception {
ParseResp parseResp = submitParse(queryText, agentId, chatId);
SemanticParseInfo semanticParseInfo = parseResp.getSelectedParses().get(0);
ChatExecuteReq request = ChatExecuteReq.builder()
.queryText(parseResp.getQueryText())
.user(DataUtils.getUser())
.parseId(semanticParseInfo.getId())
.queryId(parseResp.getQueryId())
.chatId(chatId)
.saveAnswer(true)
.build();
ChatExecuteReq request =
ChatExecuteReq.builder()
.queryText(parseResp.getQueryText())
.user(DataUtils.getUser())
.parseId(semanticParseInfo.getId())
.queryId(parseResp.getQueryId())
.chatId(chatId)
.saveAnswer(true)
.build();
QueryResult queryResult = chatQueryService.performExecution(request);
queryResult.setChatContext(semanticParseInfo);
return queryResult;
@@ -53,15 +53,16 @@ public class BaseTest extends BaseApplication {
ParseResp parseResp = submitParse(queryText, agentId, chatId);
SemanticParseInfo parseInfo = parseResp.getSelectedParses().get(0);
ChatExecuteReq request = ChatExecuteReq.builder()
.queryText(parseResp.getQueryText())
.user(DataUtils.getUser())
.parseId(parseInfo.getId())
.agentId(agentId)
.chatId(chatId)
.queryId(parseResp.getQueryId())
.saveAnswer(false)
.build();
ChatExecuteReq request =
ChatExecuteReq.builder()
.queryText(parseResp.getQueryText())
.user(DataUtils.getUser())
.parseId(parseInfo.getId())
.agentId(agentId)
.chatId(chatId)
.queryId(parseResp.getQueryId())
.saveAnswer(false)
.build();
QueryResult result = chatQueryService.performExecution(request);
result.setChatContext(parseInfo);
@@ -75,10 +76,16 @@ public class BaseTest extends BaseApplication {
}
protected void assertSchemaElements(Set<SchemaElement> expected, Set<SchemaElement> actual) {
Set<String> expectedNames = expected.stream().map(s -> s.getName())
.filter(s -> s != null).collect(Collectors.toSet());
Set<String> actualNames = actual.stream().map(s -> s.getName())
.filter(s -> s != null).collect(Collectors.toSet());
Set<String> expectedNames =
expected.stream()
.map(s -> s.getName())
.filter(s -> s != null)
.collect(Collectors.toSet());
Set<String> actualNames =
actual.stream()
.map(s -> s.getName())
.filter(s -> s != null)
.collect(Collectors.toSet());
assertEquals(expectedNames, actualNames);
}
@@ -94,10 +101,10 @@ public class BaseTest extends BaseApplication {
assertSchemaElements(expectedParseInfo.getMetrics(), actualParseInfo.getMetrics());
assertSchemaElements(expectedParseInfo.getDimensions(), actualParseInfo.getDimensions());
assertEquals(expectedParseInfo.getDimensionFilters(), actualParseInfo.getDimensionFilters());
assertEquals(
expectedParseInfo.getDimensionFilters(), actualParseInfo.getDimensionFilters());
assertEquals(expectedParseInfo.getMetricFilters(), actualParseInfo.getMetricFilters());
assertEquals(expectedParseInfo.getDateInfo(), actualParseInfo.getDateInfo());
}
}

View File

@@ -32,13 +32,16 @@ public class DetailTest extends BaseTest {
expectedParseInfo.setQueryType(QueryType.DETAIL);
expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
QueryFilter dimensionFilter = DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS,
"周杰伦", "歌手名", 8L);
QueryFilter dimensionFilter =
DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS, "周杰伦", "歌手名", 8L);
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
expectedParseInfo.getDimensions().addAll(Lists.newArrayList(
SchemaElement.builder().name("流派").build(),
SchemaElement.builder().name("代表作").build()));
expectedParseInfo
.getDimensions()
.addAll(
Lists.newArrayList(
SchemaElement.builder().name("流派").build(),
SchemaElement.builder().name("代表作").build()));
assertQueryResult(expectedResult, actualResult);
}
@@ -55,17 +58,19 @@ public class DetailTest extends BaseTest {
expectedParseInfo.setQueryType(QueryType.DETAIL);
expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
QueryFilter dimensionFilter = DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS,
"周杰伦", "歌手名", 8L);
QueryFilter dimensionFilter =
DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS, "周杰伦", "歌手名", 8L);
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
expectedParseInfo.getMetrics().add(SchemaElement.builder().name("播放量").build());
expectedParseInfo.getDimensions().addAll(Lists.newArrayList(
SchemaElement.builder().name("歌手名").build(),
SchemaElement.builder().name("活跃区域").build(),
SchemaElement.builder().name("流派").build(),
SchemaElement.builder().name("代表作").build()
));
expectedParseInfo
.getDimensions()
.addAll(
Lists.newArrayList(
SchemaElement.builder().name("歌手名").build(),
SchemaElement.builder().name("活跃区域").build(),
SchemaElement.builder().name("流派").build(),
SchemaElement.builder().name("代表作").build()));
assertQueryResult(expectedResult, actualResult);
}
@@ -82,19 +87,20 @@ public class DetailTest extends BaseTest {
expectedParseInfo.setQueryType(QueryType.DETAIL);
expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
QueryFilter dimensionFilter = DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS,
"国风", "流派", 7L);
QueryFilter dimensionFilter =
DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS, "国风", "流派", 7L);
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
expectedParseInfo.getMetrics().add(SchemaElement.builder().name("播放量").build());
expectedParseInfo.getDimensions().addAll(Lists.newArrayList(
SchemaElement.builder().name("歌手名").build(),
SchemaElement.builder().name("活跃区域").build(),
SchemaElement.builder().name("流派").build(),
SchemaElement.builder().name("代表作").build()
));
expectedParseInfo
.getDimensions()
.addAll(
Lists.newArrayList(
SchemaElement.builder().name("歌手名").build(),
SchemaElement.builder().name("活跃区域").build(),
SchemaElement.builder().name("流派").build(),
SchemaElement.builder().name("代表作").build()));
assertQueryResult(expectedResult, actualResult);
}
}

View File

@@ -1,11 +1,11 @@
package com.tencent.supersonic.chat;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricFilterQuery;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricGroupByQuery;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricModelQuery;
@@ -38,16 +38,20 @@ public class MetricTest extends BaseTest {
expectedParseInfo.setAggType(NONE);
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
FilterOperatorEnum.EQUALS, "alice", "用户", 2L));
expectedParseInfo
.getDimensionFilters()
.add(
DataUtils.getFilter(
"user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L));
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
expectedParseInfo.setDateInfo(
DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
expectedParseInfo.setQueryType(QueryType.METRIC);
assertQueryResult(expectedResult, actualResult);
}
//@Test
// @Test
public void testMetricDomain() throws Exception {
QueryResult actualResult = submitNewChat("超音数总访问次数", DataUtils.metricAgentId);
@@ -58,7 +62,8 @@ public class MetricTest extends BaseTest {
expectedResult.setQueryMode(MetricModelQuery.QUERY_MODE);
expectedParseInfo.setAggType(NONE);
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
expectedParseInfo.setDateInfo(
DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
expectedParseInfo.setQueryType(QueryType.METRIC);
assertQueryResult(expectedResult, actualResult);
@@ -78,7 +83,8 @@ public class MetricTest extends BaseTest {
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门"));
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
expectedParseInfo.setDateInfo(
DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
expectedParseInfo.setQueryType(QueryType.METRIC);
assertQueryResult(expectedResult, actualResult);
@@ -99,10 +105,12 @@ public class MetricTest extends BaseTest {
List<String> list = new ArrayList<>();
list.add("alice");
list.add("lucy");
QueryFilter dimensionFilter = DataUtils.getFilter("user_name", FilterOperatorEnum.IN, list, "用户", 2L);
QueryFilter dimensionFilter =
DataUtils.getFilter("user_name", FilterOperatorEnum.IN, list, "用户", 2L);
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
expectedParseInfo.setDateInfo(
DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
expectedParseInfo.setQueryType(QueryType.METRIC);
assertQueryResult(expectedResult, actualResult);
@@ -142,7 +150,8 @@ public class MetricTest extends BaseTest {
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门"));
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
expectedParseInfo.setDateInfo(
DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
expectedParseInfo.setQueryType(QueryType.METRIC);
assertQueryResult(expectedResult, actualResult);
@@ -154,7 +163,8 @@ public class MetricTest extends BaseTest {
DateFormat textFormat = new SimpleDateFormat("yyyy年mm月dd日");
String dateStr = textFormat.format(format.parse(startDay));
QueryResult actualResult = submitNewChat(String.format("想知道%salice的访问次数", dateStr), DataUtils.metricAgentId);
QueryResult actualResult =
submitNewChat(String.format("想知道%salice的访问次数", dateStr), DataUtils.metricAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
@@ -164,13 +174,16 @@ public class MetricTest extends BaseTest {
expectedParseInfo.setAggType(NONE);
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
FilterOperatorEnum.EQUALS, "alice", "用户", 2L));
expectedParseInfo
.getDimensionFilters()
.add(
DataUtils.getFilter(
"user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L));
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.BETWEEN, 1, period, startDay, startDay));
expectedParseInfo.setDateInfo(
DataUtils.getDateConf(DateConf.DateMode.BETWEEN, 1, period, startDay, startDay));
expectedParseInfo.setQueryType(QueryType.METRIC);
assertQueryResult(expectedResult, actualResult);
}
}

View File

@@ -1,10 +1,10 @@
package com.tencent.supersonic.chat;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricFilterQuery;
import com.tencent.supersonic.util.DataUtils;
import org.junit.jupiter.api.Order;
@@ -17,8 +17,9 @@ public class MultiTurnsTest extends BaseTest {
@Test
@Order(1)
public void queryTest_01() throws Exception {
QueryResult actualResult = submitMultiTurnChat("alice的访问次数",
DataUtils.metricAgentId, DataUtils.MULTI_TURNS_CHAT_ID);
QueryResult actualResult =
submitMultiTurnChat(
"alice的访问次数", DataUtils.metricAgentId, DataUtils.MULTI_TURNS_CHAT_ID);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
@@ -29,10 +30,14 @@ public class MultiTurnsTest extends BaseTest {
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
FilterOperatorEnum.EQUALS, "alice", "用户", 2L));
expectedParseInfo
.getDimensionFilters()
.add(
DataUtils.getFilter(
"user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L));
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
expectedParseInfo.setDateInfo(
DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
expectedParseInfo.setQueryType(QueryType.METRIC);
assertQueryResult(expectedResult, actualResult);
@@ -41,8 +46,9 @@ public class MultiTurnsTest extends BaseTest {
@Test
@Order(2)
public void queryTest_02() throws Exception {
QueryResult actualResult = submitMultiTurnChat("停留时长呢", DataUtils.metricAgentId,
DataUtils.MULTI_TURNS_CHAT_ID);
QueryResult actualResult =
submitMultiTurnChat(
"停留时长呢", DataUtils.metricAgentId, DataUtils.MULTI_TURNS_CHAT_ID);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
@@ -53,10 +59,14 @@ public class MultiTurnsTest extends BaseTest {
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("停留时长"));
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
FilterOperatorEnum.EQUALS, "alice", "用户", 2L));
expectedParseInfo
.getDimensionFilters()
.add(
DataUtils.getFilter(
"user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L));
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
expectedParseInfo.setDateInfo(
DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
expectedParseInfo.setQueryType(QueryType.METRIC);
assertQueryResult(expectedResult, actualResult);
@@ -65,8 +75,9 @@ public class MultiTurnsTest extends BaseTest {
@Test
@Order(3)
public void queryTest_03() throws Exception {
QueryResult actualResult = submitMultiTurnChat("lucy的如何", DataUtils.metricAgentId,
DataUtils.MULTI_TURNS_CHAT_ID);
QueryResult actualResult =
submitMultiTurnChat(
"lucy的如何", DataUtils.metricAgentId, DataUtils.MULTI_TURNS_CHAT_ID);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
@@ -77,13 +88,14 @@ public class MultiTurnsTest extends BaseTest {
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("停留时长"));
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
FilterOperatorEnum.EQUALS, "lucy", "用户", 2L));
expectedParseInfo
.getDimensionFilters()
.add(DataUtils.getFilter("user_name", FilterOperatorEnum.EQUALS, "lucy", "用户", 2L));
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
expectedParseInfo.setDateInfo(
DataUtils.getDateConf(DateConf.DateMode.RECENT, unit, period, startDay, endDay));
expectedParseInfo.setQueryType(QueryType.METRIC);
assertQueryResult(expectedResult, actualResult);
}
}

View File

@@ -95,7 +95,10 @@ public class Text2SQLEval extends BaseTest {
QueryResult result = submitNewChat("过去半个月核心用户的总停留时长", agentId);
assert result.getQueryColumns().size() >= 1;
assert result.getQueryColumns().stream()
.filter(c -> c.getName().contains("停留时长")).collect(Collectors.toList()).size() == 1;
.filter(c -> c.getName().contains("停留时长"))
.collect(Collectors.toList())
.size()
== 1;
assert result.getQueryResults().size() >= 1;
}

View File

@@ -12,9 +12,9 @@ import com.tencent.supersonic.headless.api.pojo.request.QuerySqlReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.request.SemanticQueryReq;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import com.tencent.supersonic.headless.server.persistence.dataobject.DomainDO;
import com.tencent.supersonic.headless.server.persistence.repository.DomainRepository;
import com.tencent.supersonic.headless.server.facade.service.SemanticLayerService;
import com.tencent.supersonic.util.DataUtils;
import org.apache.commons.collections.CollectionUtils;
import org.springframework.beans.factory.annotation.Autowired;
@@ -27,11 +27,9 @@ import static java.time.LocalDate.now;
public class BaseTest extends BaseApplication {
@Autowired
protected SemanticLayerService semanticLayerService;
@Autowired protected SemanticLayerService semanticLayerService;
@Autowired
private DomainRepository domainRepository;
@Autowired private DomainRepository domainRepository;
protected SemanticQueryResp queryBySql(String sql) throws Exception {
return queryBySql(sql, User.getFakeUser());
@@ -81,8 +79,7 @@ public class BaseTest extends BaseApplication {
return queryStructReq;
}
protected QueryStructReq buildQueryStructReq(List<String> groups,
Aggregator aggregator) {
protected QueryStructReq buildQueryStructReq(List<String> groups, Aggregator aggregator) {
QueryStructReq queryStructReq = new QueryStructReq();
for (Long modelId : DataUtils.getMetricAgentIModelIds()) {
queryStructReq.addModelId(modelId);
@@ -108,5 +105,4 @@ public class BaseTest extends BaseApplication {
domainDO.setIsOpen(0);
domainRepository.updateDomain(domainDO);
}
}

View File

@@ -18,13 +18,10 @@ import java.util.Arrays;
import java.util.Date;
import java.util.List;
public class DictTest extends BaseTest {
@Autowired
private DictConfMapper confMapper;
@Autowired private DictConfMapper confMapper;
@Autowired
private DictTaskService taskService;
@Autowired private DictTaskService taskService;
@Test
public void insertConf() {
@@ -83,12 +80,15 @@ public class DictTest extends BaseTest {
void testAddTask() {
editConf();
DictConfDO confDODb = confMapper.selectById(1L);
DictSingleTaskReq dictTask = DictSingleTaskReq.builder().itemId(confDODb.getItemId())
.type(TypeEnums.DIMENSION).build();
DictSingleTaskReq dictTask =
DictSingleTaskReq.builder()
.itemId(confDODb.getItemId())
.type(TypeEnums.DIMENSION)
.build();
taskService.addDictTask(dictTask, null);
DictSingleTaskReq taskReq = DictSingleTaskReq.builder().itemId(3L).type(TypeEnums.DIMENSION).build();
DictSingleTaskReq taskReq =
DictSingleTaskReq.builder().itemId(3L).type(TypeEnums.DIMENSION).build();
taskService.deleteDictTask(taskReq, null);
System.out.println();
}
}
}

View File

@@ -1,8 +1,5 @@
package com.tencent.supersonic.headless;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import com.tencent.supersonic.auth.api.authentication.utils.UserHolder;
import com.tencent.supersonic.auth.authentication.strategy.FakeUserStrategy;
import com.tencent.supersonic.headless.server.task.FlightServerInitTask;
@@ -19,36 +16,37 @@ import org.apache.arrow.memory.RootAllocator;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@Slf4j
public class FlightSqlTest extends BaseTest {
@Autowired
private FlightServerInitTask flightSqlListener;
@Autowired
private FakeUserStrategy fakeUserStrategy;
@Autowired private FlightServerInitTask flightSqlListener;
@Autowired private FakeUserStrategy fakeUserStrategy;
@Test
void test01() throws Exception {
startServer();
String host = flightSqlListener.getHost();
Integer port = flightSqlListener.getPort();
FlightSqlClient sqlClient = new FlightSqlClient(
FlightClient.builder(new RootAllocator(Integer.MAX_VALUE), Location.forGrpcInsecure(host, port))
.build());
FlightSqlClient sqlClient =
new FlightSqlClient(
FlightClient.builder(
new RootAllocator(Integer.MAX_VALUE),
Location.forGrpcInsecure(host, port))
.build());
CallHeaders headers = new FlightCallHeaders();
headers.insert("dataSetId", "1");
headers.insert("name", "admin");
headers.insert("password", "admin");
HeaderCallOption headerOption = new HeaderCallOption(headers);
try (final FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare(
"SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门",
headerOption)) {
try (final FlightSqlClient.PreparedStatement preparedStatement =
sqlClient.prepare(
"SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门", headerOption)) {
final FlightInfo info = preparedStatement.execute();
FlightStream stream = sqlClient.getStream(info
.getEndpoints()
.get(0).getTicket());
FlightStream stream = sqlClient.getStream(info.getEndpoints().get(0).getTicket());
int rowCnt = 0;
int colCnt = 0;
while (stream.next()) {
@@ -69,9 +67,12 @@ public class FlightSqlTest extends BaseTest {
startServer();
String host = flightSqlListener.getHost();
Integer port = flightSqlListener.getPort();
FlightSqlClient sqlClient = new FlightSqlClient(
FlightClient.builder(new RootAllocator(Integer.MAX_VALUE), Location.forGrpcInsecure(host, port))
.build());
FlightSqlClient sqlClient =
new FlightSqlClient(
FlightClient.builder(
new RootAllocator(Integer.MAX_VALUE),
Location.forGrpcInsecure(host, port))
.build());
CallHeaders headers = new FlightCallHeaders();
headers.insert("dataSetId", "1");
@@ -79,12 +80,11 @@ public class FlightSqlTest extends BaseTest {
headers.insert("password", "admin");
HeaderCallOption headerOption = new HeaderCallOption(headers);
try {
FlightInfo flightInfo = sqlClient.execute(
"SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门",
headerOption);
FlightStream stream = sqlClient.getStream(flightInfo
.getEndpoints()
.get(0).getTicket());
FlightInfo flightInfo =
sqlClient.execute(
"SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门",
headerOption);
FlightStream stream = sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket());
int rowCnt = 0;
int colCnt = 0;
while (stream.next()) {

View File

@@ -14,8 +14,7 @@ import java.util.Collections;
public class MetaDiscoveryTest extends BaseTest {
@Autowired
protected ChatLayerService chatLayerService;
@Autowired protected ChatLayerService chatLayerService;
@Test
public void testGetMapMeta() throws Exception {

View File

@@ -15,8 +15,7 @@ import java.util.stream.Collectors;
public class ModelSchemaTest extends BaseTest {
@Autowired
private ModelService modelService;
@Autowired private ModelService modelService;
@Test
void testGetUnAvailableItem() {
@@ -25,9 +24,11 @@ public class ModelSchemaTest extends BaseTest {
fieldRemovedReq.setFields(Lists.newArrayList("pv"));
UnAvailableItemResp unAvailableItemResp = modelService.getUnAvailableItem(fieldRemovedReq);
List<Long> expectedUnAvailableMetricId = Lists.newArrayList(1L, 4L);
List<Long> actualUnAvailableMetricId = unAvailableItemResp.getMetricResps()
.stream().map(MetricResp::getId).sorted(Comparator.naturalOrder()).collect(Collectors.toList());
List<Long> actualUnAvailableMetricId =
unAvailableItemResp.getMetricResps().stream()
.map(MetricResp::getId)
.sorted(Comparator.naturalOrder())
.collect(Collectors.toList());
Assertions.assertEquals(expectedUnAvailableMetricId, actualUnAvailableMetricId);
}
}

View File

@@ -15,8 +15,7 @@ import static org.junit.Assert.assertThrows;
public class QueryByMetricTest extends BaseTest {
@Autowired
protected MetricService metricService;
@Autowired protected MetricService metricService;
@Test
public void testWithMetricAndDimensionBizNames() throws Exception {
@@ -51,7 +50,8 @@ public class QueryByMetricTest extends BaseTest {
queryMetricReq.setDomainId(2L);
queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv"));
queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department"));
assertThrows(IllegalArgumentException.class,
assertThrows(
IllegalArgumentException.class,
() -> queryByMetric(queryMetricReq, User.getFakeUser()));
}
@@ -66,7 +66,8 @@ public class QueryByMetricTest extends BaseTest {
Assert.assertEquals(6, queryResp.getResultList().size());
}
private SemanticQueryResp queryByMetric(QueryMetricReq queryMetricReq, User user) throws Exception {
private SemanticQueryResp queryByMetric(QueryMetricReq queryMetricReq, User user)
throws Exception {
QueryStructReq convert = metricService.convert(queryMetricReq);
return semanticLayerService.queryByReq(convert.convert(), user);
}

View File

@@ -17,7 +17,8 @@ public class QueryBySqlTest extends BaseTest {
@Test
public void testDetailQuery() throws Exception {
SemanticQueryResp semanticQueryResp = queryBySql("SELECT 用户,访问次数 FROM 超音数PVUV统计 WHERE 用户='alice' ");
SemanticQueryResp semanticQueryResp =
queryBySql("SELECT 用户,访问次数 FROM 超音数PVUV统计 WHERE 用户='alice' ");
assertEquals(2, semanticQueryResp.getColumns().size());
QueryColumn firstColumn = semanticQueryResp.getColumns().get(0);
@@ -29,7 +30,8 @@ public class QueryBySqlTest extends BaseTest {
@Test
public void testSumQuery() throws Exception {
SemanticQueryResp semanticQueryResp = queryBySql("SELECT SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 ");
SemanticQueryResp semanticQueryResp =
queryBySql("SELECT SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 ");
assertEquals(1, semanticQueryResp.getColumns().size());
QueryColumn queryColumn = semanticQueryResp.getColumns().get(0);
@@ -39,7 +41,8 @@ public class QueryBySqlTest extends BaseTest {
@Test
public void testGroupByQuery() throws Exception {
SemanticQueryResp result = queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ");
SemanticQueryResp result =
queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ");
assertEquals(2, result.getColumns().size());
QueryColumn firstColumn = result.getColumns().get(0);
QueryColumn secondColumn = result.getColumns().get(1);
@@ -50,8 +53,9 @@ public class QueryBySqlTest extends BaseTest {
@Test
public void testFilterQuery() throws Exception {
SemanticQueryResp result = queryBySql(
"SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 WHERE 部门 ='HR' GROUP BY 部门 ");
SemanticQueryResp result =
queryBySql(
"SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 WHERE 部门 ='HR' GROUP BY 部门 ");
assertEquals(2, result.getColumns().size());
QueryColumn firstColumn = result.getColumns().get(0);
QueryColumn secondColumn = result.getColumns().get(1);
@@ -76,13 +80,15 @@ public class QueryBySqlTest extends BaseTest {
@Test
public void testCacheQuery() throws Exception {
queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ");
SemanticQueryResp result2 = queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ");
SemanticQueryResp result2 =
queryBySql("SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ");
assertTrue(result2.isUseCache());
}
@Test
public void testBizNameQuery() throws Exception {
SemanticQueryResp result1 = queryBySql("SELECT SUM(pv) FROM 超音数PVUV统计 WHERE department ='HR'");
SemanticQueryResp result1 =
queryBySql("SELECT SUM(pv) FROM 超音数PVUV统计 WHERE department ='HR'");
SemanticQueryResp result2 = queryBySql("SELECT SUM(访问次数) FROM 超音数PVUV统计 WHERE 部门 ='HR'");
assertEquals(1, result1.getColumns().size());
assertEquals(1, result2.getColumns().size());
@@ -94,15 +100,19 @@ public class QueryBySqlTest extends BaseTest {
public void testAuthorization_model() {
User alice = DataUtils.getUserAlice();
setDomainNotOpenToAll();
assertThrows(InvalidPermissionException.class,
assertThrows(
InvalidPermissionException.class,
() -> queryBySql("SELECT SUM(pv) FROM 超音数PVUV统计 WHERE department ='HR'", alice));
}
@Test
public void testAuthorization_sensitive_metric() throws Exception {
User tom = DataUtils.getUserTom();
assertThrows(InvalidPermissionException.class,
() -> queryBySql("SELECT SUM(stay_hours) FROM 停留时长统计 WHERE department ='HR'", tom));
assertThrows(
InvalidPermissionException.class,
() ->
queryBySql(
"SELECT SUM(stay_hours) FROM 停留时长统计 WHERE department ='HR'", tom));
}
@Test
@@ -121,5 +131,4 @@ public class QueryBySqlTest extends BaseTest {
Assertions.assertNotNull(semanticQueryResp.getQueryAuthorization().getMessage());
Assertions.assertTrue(semanticQueryResp.getSql().contains("user_name = 'tom'"));
}
}

View File

@@ -19,9 +19,11 @@ import org.junit.jupiter.api.MethodOrderer;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestMethodOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThrows;
@@ -44,9 +46,10 @@ public class QueryByStructTest extends BaseTest {
@Test
public void testDetailQuery() throws Exception {
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("user_name", "department"),
QueryType.DETAIL);
SemanticQueryResp semanticQueryResp = semanticLayerService.queryByReq(queryStructReq, User.getFakeUser());
QueryStructReq queryStructReq =
buildQueryStructReq(Arrays.asList("user_name", "department"), QueryType.DETAIL);
SemanticQueryResp semanticQueryResp =
semanticLayerService.queryByReq(queryStructReq, User.getFakeUser());
assertEquals(3, semanticQueryResp.getColumns().size());
QueryColumn firstColumn = semanticQueryResp.getColumns().get(0);
assertEquals("用户", firstColumn.getName());
@@ -60,7 +63,8 @@ public class QueryByStructTest extends BaseTest {
@Test
public void testSumQuery() throws Exception {
QueryStructReq queryStructReq = buildQueryStructReq(null);
SemanticQueryResp semanticQueryResp = semanticLayerService.queryByReq(queryStructReq, User.getFakeUser());
SemanticQueryResp semanticQueryResp =
semanticLayerService.queryByReq(queryStructReq, User.getFakeUser());
assertEquals(1, semanticQueryResp.getColumns().size());
QueryColumn queryColumn = semanticQueryResp.getColumns().get(0);
assertEquals("访问次数", queryColumn.getName());
@@ -70,7 +74,8 @@ public class QueryByStructTest extends BaseTest {
@Test
public void testGroupByQuery() throws Exception {
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department"));
SemanticQueryResp result = semanticLayerService.queryByReq(queryStructReq, User.getFakeUser());
SemanticQueryResp result =
semanticLayerService.queryByReq(queryStructReq, User.getFakeUser());
assertEquals(2, result.getColumns().size());
QueryColumn firstColumn = result.getColumns().get(0);
QueryColumn secondColumn = result.getColumns().get(1);
@@ -91,7 +96,8 @@ public class QueryByStructTest extends BaseTest {
dimensionFilters.add(filter);
queryStructReq.setDimensionFilters(dimensionFilters);
SemanticQueryResp result = semanticLayerService.queryByReq(queryStructReq, User.getFakeUser());
SemanticQueryResp result =
semanticLayerService.queryByReq(queryStructReq, User.getFakeUser());
assertEquals(2, result.getColumns().size());
QueryColumn firstColumn = result.getColumns().get(0);
QueryColumn secondColumn = result.getColumns().get(1);
@@ -106,7 +112,8 @@ public class QueryByStructTest extends BaseTest {
User alice = DataUtils.getUserAlice();
setDomainNotOpenToAll();
QueryStructReq queryStructReq1 = buildQueryStructReq(Arrays.asList("department"));
assertThrows(InvalidPermissionException.class,
assertThrows(
InvalidPermissionException.class,
() -> semanticLayerService.queryByReq(queryStructReq1, alice));
}
@@ -116,8 +123,10 @@ public class QueryByStructTest extends BaseTest {
Aggregator aggregator = new Aggregator();
aggregator.setFunc(AggOperatorEnum.SUM);
aggregator.setColumn("stay_hours");
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department"), aggregator);
assertThrows(InvalidPermissionException.class,
QueryStructReq queryStructReq =
buildQueryStructReq(Arrays.asList("department"), aggregator);
assertThrows(
InvalidPermissionException.class,
() -> semanticLayerService.queryByReq(queryStructReq, tom));
}
@@ -127,10 +136,10 @@ public class QueryByStructTest extends BaseTest {
Aggregator aggregator = new Aggregator();
aggregator.setFunc(AggOperatorEnum.SUM);
aggregator.setColumn("pv");
QueryStructReq queryStructReq1 = buildQueryStructReq(Arrays.asList("department"), aggregator);
QueryStructReq queryStructReq1 =
buildQueryStructReq(Arrays.asList("department"), aggregator);
SemanticQueryResp semanticQueryResp = semanticLayerService.queryByReq(queryStructReq1, tom);
Assertions.assertNotNull(semanticQueryResp.getQueryAuthorization().getMessage());
Assertions.assertTrue(semanticQueryResp.getSql().contains("`user_name` = 'tom'"));
}
}

View File

@@ -14,9 +14,9 @@ public class QueryDimensionTest extends BaseTest {
queryDimValueReq.setModelId(1L);
queryDimValueReq.setBizName("department");
SemanticQueryResp queryResp = semanticLayerService.queryDimensionValue(queryDimValueReq, User.getFakeUser());
SemanticQueryResp queryResp =
semanticLayerService.queryDimensionValue(queryDimValueReq, User.getFakeUser());
Assert.assertNotNull(queryResp.getResultList());
Assert.assertEquals(4, queryResp.getResultList().size());
}
}

View File

@@ -19,8 +19,7 @@ import java.util.List;
public class QueryRuleTest extends BaseTest {
@Autowired
private QueryRuleService queryRuleService;
@Autowired private QueryRuleService queryRuleService;
private User user = User.getFakeUser();
@@ -93,7 +92,8 @@ public class QueryRuleTest extends BaseTest {
queryRuleService.addQueryRule(queryRuleReq2, user);
QueryRuleFilter queryRuleFilter = new QueryRuleFilter();
List<QueryRuleResp> queryRuleList = queryRuleService.getQueryRuleList(queryRuleFilter, user);
List<QueryRuleResp> queryRuleList =
queryRuleService.getQueryRuleList(queryRuleFilter, user);
queryRuleList.size();
}
}
}

View File

@@ -3,29 +3,27 @@ package com.tencent.supersonic.headless;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.enums.AuthType;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.api.pojo.response.DomainResp;
import com.tencent.supersonic.headless.api.pojo.response.ModelResp;
import com.tencent.supersonic.headless.api.pojo.response.DataSetResp;
import com.tencent.supersonic.headless.server.service.DataSetService;
import com.tencent.supersonic.headless.server.service.DomainService;
import com.tencent.supersonic.headless.server.service.ModelService;
import com.tencent.supersonic.headless.server.service.DataSetService;
import com.tencent.supersonic.util.DataUtils;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import java.util.List;
import java.util.stream.Collectors;
public class SchemaAuthTest extends BaseTest {
@Autowired
private DomainService domainService;
@Autowired private DomainService domainService;
@Autowired
private DataSetService dataSetService;
@Autowired private DataSetService dataSetService;
@Autowired
private ModelService modelService;
@Autowired private ModelService modelService;
@Test
public void test_getDomainList_alice() {
@@ -33,7 +31,8 @@ public class SchemaAuthTest extends BaseTest {
setDomainNotOpenToAll();
List<DomainResp> domainResps = domainService.getDomainListWithAdminAuth(user);
List<String> expectedDomainBizNames = Lists.newArrayList("supersonic", "singer");
Assertions.assertEquals(expectedDomainBizNames,
Assertions.assertEquals(
expectedDomainBizNames,
domainResps.stream().map(DomainResp::getBizName).collect(Collectors.toList()));
}
@@ -42,16 +41,19 @@ public class SchemaAuthTest extends BaseTest {
User user = DataUtils.getUserAlice();
List<ModelResp> modelResps = modelService.getModelListWithAuth(user, null, AuthType.ADMIN);
List<String> expectedModelBizNames = Lists.newArrayList("user_department", "singer");
Assertions.assertEquals(expectedModelBizNames,
Assertions.assertEquals(
expectedModelBizNames,
modelResps.stream().map(ModelResp::getBizName).collect(Collectors.toList()));
}
@Test
public void test_getVisibleModelList_alice() {
User user = DataUtils.getUserAlice();
List<ModelResp> modelResps = modelService.getModelListWithAuth(user, null, AuthType.VISIBLE);
List<ModelResp> modelResps =
modelService.getModelListWithAuth(user, null, AuthType.VISIBLE);
List<String> expectedModelBizNames = Lists.newArrayList("user_department", "singer");
Assertions.assertEquals(expectedModelBizNames,
Assertions.assertEquals(
expectedModelBizNames,
modelResps.stream().map(ModelResp::getBizName).collect(Collectors.toList()));
}
@@ -60,7 +62,8 @@ public class SchemaAuthTest extends BaseTest {
User user = DataUtils.getUserAlice();
List<DataSetResp> dataSetResps = dataSetService.getDataSetsInheritAuth(user, 0L);
List<String> expectedDataSetBizNames = Lists.newArrayList("singer");
Assertions.assertEquals(expectedDataSetBizNames,
Assertions.assertEquals(
expectedDataSetBizNames,
dataSetResps.stream().map(DataSetResp::getBizName).collect(Collectors.toList()));
}
@@ -69,7 +72,8 @@ public class SchemaAuthTest extends BaseTest {
User user = DataUtils.getUserJack();
List<DomainResp> domainResps = domainService.getDomainListWithAdminAuth(user);
List<String> expectedDomainBizNames = Lists.newArrayList("supersonic");
Assertions.assertEquals(expectedDomainBizNames,
Assertions.assertEquals(
expectedDomainBizNames,
domainResps.stream().map(DomainResp::getBizName).collect(Collectors.toList()));
}
@@ -77,9 +81,10 @@ public class SchemaAuthTest extends BaseTest {
public void test_getModelList_jack() {
User user = DataUtils.getUserJack();
List<ModelResp> modelResps = modelService.getModelListWithAuth(user, null, AuthType.ADMIN);
List<String> expectedModelBizNames = Lists.newArrayList("user_department",
"s2_pv_uv_statis", "s2_stay_time_statis");
Assertions.assertEquals(expectedModelBizNames,
List<String> expectedModelBizNames =
Lists.newArrayList("user_department", "s2_pv_uv_statis", "s2_stay_time_statis");
Assertions.assertEquals(
expectedModelBizNames,
modelResps.stream().map(ModelResp::getBizName).collect(Collectors.toList()));
}
@@ -88,8 +93,8 @@ public class SchemaAuthTest extends BaseTest {
User user = DataUtils.getUserJack();
List<DataSetResp> dataSetResps = dataSetService.getDataSetsInheritAuth(user, 0L);
List<String> expectedDataSetBizNames = Lists.newArrayList("s2", "singer");
Assertions.assertEquals(expectedDataSetBizNames,
Assertions.assertEquals(
expectedDataSetBizNames,
dataSetResps.stream().map(DataSetResp::getBizName).collect(Collectors.toList()));
}
}

View File

@@ -13,8 +13,7 @@ import java.util.List;
public class TagObjectTest extends BaseTest {
@Autowired
private TagObjectService tagObjectService;
@Autowired private TagObjectService tagObjectService;
@Test
void testCreateTagObject() throws Exception {
@@ -32,7 +31,8 @@ public class TagObjectTest extends BaseTest {
BeanUtils.copyProperties(tagObjectResp, tagObjectReqUpdate);
tagObjectReqUpdate.setName("艺人1");
tagObjectService.update(tagObjectReqUpdate, User.getFakeUser());
TagObjectResp tagObject = tagObjectService.getTagObject(tagObjectReqUpdate.getId(), User.getFakeUser());
TagObjectResp tagObject =
tagObjectService.getTagObject(tagObjectReqUpdate.getId(), User.getFakeUser());
tagObjectService.delete(tagObject.getId(), User.getFakeUser(), false);
}
@@ -53,5 +53,4 @@ public class TagObjectTest extends BaseTest {
tagObjectReq.setBizName("new_singer");
return tagObjectReq;
}
}
}

View File

@@ -13,15 +13,14 @@ import org.springframework.beans.factory.annotation.Autowired;
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
public class TagTest extends BaseTest {
@Autowired
private TagQueryService tagQueryService;
@Autowired private TagQueryService tagQueryService;
@Test
public void testQueryTagValue() throws Exception {
ItemValueReq itemValueReq = new ItemValueReq();
itemValueReq.setId(1L);
ItemValueResp itemValueResp = tagQueryService.queryTagValue(itemValueReq, User.getFakeUser());
ItemValueResp itemValueResp =
tagQueryService.queryTagValue(itemValueReq, User.getFakeUser());
Assertions.assertNotNull(itemValueResp);
}
}
}

View File

@@ -1,6 +1,5 @@
package com.tencent.supersonic.headless;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.response.SemanticTranslateResp;
@@ -18,8 +17,10 @@ public class TranslateTest extends BaseTest {
@Test
public void testSqlExplain() throws Exception {
String sql = "SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ";
SemanticTranslateResp explain = semanticLayerService.translate(QueryReqBuilder.buildS2SQLReq(sql,
DataUtils.getMetricAgentView()), User.getFakeUser());
SemanticTranslateResp explain =
semanticLayerService.translate(
QueryReqBuilder.buildS2SQLReq(sql, DataUtils.getMetricAgentView()),
User.getFakeUser());
assertNotNull(explain);
assertNotNull(explain.getQuerySQL());
assertTrue(explain.getQuerySQL().contains("department"));
@@ -29,11 +30,11 @@ public class TranslateTest extends BaseTest {
@Test
public void testStructExplain() throws Exception {
QueryStructReq queryStructReq = buildQueryStructReq(Arrays.asList("department"));
SemanticTranslateResp explain = semanticLayerService.translate(queryStructReq, User.getFakeUser());
SemanticTranslateResp explain =
semanticLayerService.translate(queryStructReq, User.getFakeUser());
assertNotNull(explain);
assertNotNull(explain.getQuerySQL());
assertTrue(explain.getQuerySQL().contains("department"));
assertTrue(explain.getQuerySQL().contains("pv"));
}
}

View File

@@ -51,9 +51,11 @@ public class ModelProviderTest extends BaseApplication {
modelConfig.setEndpoint(QianfanModelFactory.DEFAULT_ENDPOINT);
ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
chatModel.generate("hi");
});
assertThrows(
RuntimeException.class,
() -> {
chatModel.generate("hi");
});
}
@Test
@@ -65,9 +67,11 @@ public class ModelProviderTest extends BaseApplication {
modelConfig.setApiKey("e2724491714b3b2a0274e987905f1001.5JyHgf4vbZVJ7gC5");
ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
chatModel.generate("hi");
});
assertThrows(
RuntimeException.class,
() -> {
chatModel.generate("hi");
});
}
@Test
@@ -80,9 +84,11 @@ public class ModelProviderTest extends BaseApplication {
modelConfig.setApiKey(ParameterConfig.DEMO);
ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
chatModel.generate("hi");
});
assertThrows(
RuntimeException.class,
() -> {
chatModel.generate("hi");
});
}
@Test
@@ -94,9 +100,11 @@ public class ModelProviderTest extends BaseApplication {
modelConfig.setApiKey(ParameterConfig.DEMO);
ChatLanguageModel chatModel = ModelProvider.getChatModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
chatModel.generate("hi");
});
assertThrows(
RuntimeException.class,
() -> {
chatModel.generate("hi");
});
}
@Test
@@ -132,9 +140,11 @@ public class ModelProviderTest extends BaseApplication {
modelConfig.setApiKey(ParameterConfig.DEMO);
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
embeddingModel.embed("hi");
});
assertThrows(
RuntimeException.class,
() -> {
embeddingModel.embed("hi");
});
}
@Test
@@ -146,9 +156,11 @@ public class ModelProviderTest extends BaseApplication {
modelConfig.setApiKey(ParameterConfig.DEMO);
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
embeddingModel.embed("hi");
});
assertThrows(
RuntimeException.class,
() -> {
embeddingModel.embed("hi");
});
}
@Test
@@ -161,9 +173,11 @@ public class ModelProviderTest extends BaseApplication {
modelConfig.setSecretKey(ParameterConfig.DEMO);
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
embeddingModel.embed("hi");
});
assertThrows(
RuntimeException.class,
() -> {
embeddingModel.embed("hi");
});
}
@Test
@@ -175,8 +189,10 @@ public class ModelProviderTest extends BaseApplication {
modelConfig.setApiKey("e2724491714b3b2a0274e987905f1001.5JyHgf4vbZVJ7gC5");
EmbeddingModel embeddingModel = ModelProvider.getEmbeddingModel(modelConfig);
assertThrows(RuntimeException.class, () -> {
embeddingModel.embed("hi");
});
assertThrows(
RuntimeException.class,
() -> {
embeddingModel.embed("hi");
});
}
}

View File

@@ -44,6 +44,5 @@ public class AESEncryptionUtilTest {
System.out.println("after AES/ECB encrypt" + encryptStr);
String decryptStr = AESEncryptionUtil.aesDecryptECB(encryptStr);
System.out.println("after AES/ECB decrypt" + decryptStr);
}
}

View File

@@ -61,13 +61,15 @@ public class DataUtils {
}
public static SchemaElement getSchemaElement(String name) {
return SchemaElement.builder()
.name(name)
.build();
return SchemaElement.builder().name(name).build();
}
public static QueryFilter getFilter(String bizName, FilterOperatorEnum filterOperatorEnum,
Object value, String name, Long elementId) {
public static QueryFilter getFilter(
String bizName,
FilterOperatorEnum filterOperatorEnum,
Object value,
String name,
Long elementId) {
QueryFilter filter = new QueryFilter();
filter.setBizName(bizName);
filter.setOperator(filterOperatorEnum);
@@ -87,8 +89,12 @@ public class DataUtils {
return dateInfo;
}
public static DateConf getDateConf(DateConf.DateMode dateMode, Integer unit,
String period, String startDate, String endDate) {
public static DateConf getDateConf(
DateConf.DateMode dateMode,
Integer unit,
String period,
String startDate,
String endDate) {
DateConf dateInfo = new DateConf();
dateInfo.setUnit(unit);
dateInfo.setDateMode(dateMode);
@@ -98,7 +104,8 @@ public class DataUtils {
return dateInfo;
}
public static DateConf getDateConf(DateConf.DateMode dateMode, String startDate, String endDate) {
public static DateConf getDateConf(
DateConf.DateMode dateMode, String startDate, String endDate) {
DateConf dateInfo = new DateConf();
dateInfo.setDateMode(dateMode);
dateInfo.setStartDate(startDate);
@@ -106,7 +113,8 @@ public class DataUtils {
return dateInfo;
}
public static DateConf getDateConf(DateConf.DateMode dateMode, String startDate, String endDate, int unit) {
public static DateConf getDateConf(
DateConf.DateMode dateMode, String startDate, String endDate, int unit) {
DateConf dateInfo = new DateConf();
dateInfo.setDateMode(dateMode);
dateInfo.setStartDate(startDate);
@@ -141,8 +149,14 @@ public class DataUtils {
RuleParserTool ruleQueryTool = new RuleParserTool();
ruleQueryTool.setType(AgentToolType.NL2SQL_RULE);
ruleQueryTool.setDataSetIds(Lists.newArrayList(-1L));
ruleQueryTool.setQueryModes(Lists.newArrayList("METRIC_ID", "METRIC_FILTER", "METRIC_MODEL",
"TAG_DETAIL", "TAG_LIST_FILTER", "TAG_ID"));
ruleQueryTool.setQueryModes(
Lists.newArrayList(
"METRIC_ID",
"METRIC_FILTER",
"METRIC_MODEL",
"TAG_DETAIL",
"TAG_LIST_FILTER",
"TAG_ID"));
return ruleQueryTool;
}

View File

@@ -6,11 +6,12 @@ import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.chat.knowledge.DatabaseMapResult;
import com.tencent.supersonic.headless.chat.knowledge.MapResult;
import com.tencent.supersonic.headless.chat.knowledge.helper.HanlpHelper;
import java.util.ArrayList;
import java.util.List;
import org.junit.Assert;
import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
public class HanlpTest {
@Test
@@ -36,4 +37,4 @@ public class HanlpTest {
HanlpHelper.transLetterOriginal(mapResults);
Assert.assertEquals(mapResults.size(), 2);
}
}
}