Merge remote-tracking branch 'origin/master'

This commit is contained in:
lxwcodemonkey
2024-11-30 15:29:35 +08:00
155 changed files with 3758 additions and 3552 deletions

View File

@@ -108,9 +108,9 @@ public abstract class S2BaseDemo implements CommandLineRunner {
}
}
abstract void doRun();
protected abstract void doRun();
abstract boolean checkNeedToRun();
protected abstract boolean checkNeedToRun();
protected DatabaseResp addDatabaseIfNotExist() {
List<DatabaseResp> databaseList = databaseService.getDatabaseList(defaultUser);
@@ -119,8 +119,8 @@ public abstract class S2BaseDemo implements CommandLineRunner {
}
String url = dataSourceProperties.getUrl();
DatabaseReq databaseReq = new DatabaseReq();
databaseReq.setName("H2数据库DEMO");
databaseReq.setDescription("样例数据库实例仅用于体验,正式使用请切换持久化数据库");
databaseReq.setName("S2数据库DEMO");
databaseReq.setDescription("样例数据库实例仅用于体验");
if (StringUtils.isNotBlank(url)
&& url.toLowerCase().contains(DataType.MYSQL.getFeature().toLowerCase())) {
databaseReq.setType(DataType.MYSQL.getFeature());

View File

@@ -7,26 +7,12 @@ import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.DatasetTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.chat.server.processor.execute.DataInterpretProcessor;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.JoinCondition;
import com.tencent.supersonic.common.pojo.ModelRela;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.AppModule;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.TimeMode;
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
import com.tencent.supersonic.common.pojo.enums.*;
import com.tencent.supersonic.common.util.ChatAppManager;
import com.tencent.supersonic.headless.api.pojo.AggregateTypeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.DataSetDetail;
import com.tencent.supersonic.headless.api.pojo.DataSetModelConfig;
import com.tencent.supersonic.headless.api.pojo.Dim;
import com.tencent.supersonic.headless.api.pojo.DimensionTimeTypeParams;
import com.tencent.supersonic.headless.api.pojo.Identify;
import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
import com.tencent.supersonic.headless.api.pojo.*;
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
import com.tencent.supersonic.headless.api.pojo.request.DataSetReq;
@@ -40,11 +26,7 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.*;
@Component
@Slf4j
@@ -59,8 +41,8 @@ public class S2CompanyDemo extends S2BaseDemo {
ModelResp model_brand = addModel_2(domain, demoDatabase);
ModelResp model_brand_revenue = addModel_3(domain, demoDatabase);
addModelRela(domain, model_company, model_brand, "company_id");
addModelRela(domain, model_brand, model_brand_revenue, "brand_id");
addModelRela(domain, model_brand, model_company, "company_id");
addModelRela(domain, model_brand_revenue, model_brand, "brand_id");
DataSetResp dataset = addDataSet(domain);
addAgent(dataset.getId());
@@ -70,7 +52,7 @@ public class S2CompanyDemo extends S2BaseDemo {
}
@Override
boolean checkNeedToRun() {
protected boolean checkNeedToRun() {
List<DomainResp> domainList = domainService.getDomainList();
for (DomainResp domainResp : domainList) {
if (domainResp.getBizName().equalsIgnoreCase("corporate")) {
@@ -124,8 +106,7 @@ public class S2CompanyDemo extends S2BaseDemo {
modelDetail.setMeasures(measures);
modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("SELECT company_id,company_name,headquarter_address,"
+ "company_established_time,founder,ceo,annual_turnover,employee_count FROM company");
modelDetail.setSqlQuery("SELECT * FROM company");
modelReq.setModelDetail(modelDetail);
ModelResp companyModel = modelService.createModel(modelReq, defaultUser);
@@ -164,8 +145,7 @@ public class S2CompanyDemo extends S2BaseDemo {
modelDetail.setMeasures(measures);
modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("SELECT brand_id,brand_name,brand_established_time,"
+ "company_id,legal_representative,registered_capital FROM brand");
modelDetail.setSqlQuery("SELECT * FROM brand");
modelReq.setModelDetail(modelDetail);
ModelResp brandModel = modelService.createModel(modelReq, defaultUser);
@@ -205,8 +185,7 @@ public class S2CompanyDemo extends S2BaseDemo {
modelDetail.setMeasures(measures);
modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("SELECT year_time,brand_id,revenue,profit,"
+ "revenue_growth_year_on_year,profit_growth_year_on_year FROM brand_revenue");
modelDetail.setSqlQuery("SELECT * FROM brand_revenue");
modelReq.setModelDetail(modelDetail);
return modelService.createModel(modelReq, defaultUser);
}
@@ -245,7 +224,7 @@ public class S2CompanyDemo extends S2BaseDemo {
modelRelaReq.setDomainId(domain.getId());
modelRelaReq.setFromModelId(fromModel.getId());
modelRelaReq.setToModelId(toModel.getId());
modelRelaReq.setJoinType("left join");
modelRelaReq.setJoinType("inner join");
modelRelaReq.setJoinConditions(joinConditions);
modelRelaService.save(modelRelaReq, defaultUser);
}
@@ -272,7 +251,6 @@ public class S2CompanyDemo extends S2BaseDemo {
Map<String, ChatApp> chatAppConfig =
Maps.newHashMap(ChatAppManager.getAllApps(AppModule.CHAT));
chatAppConfig.values().forEach(app -> app.setChatModelId(demoChatModel.getId()));
chatAppConfig.get(DataInterpretProcessor.APP_KEY).setEnable(true);
agent.setChatAppConfig(chatAppConfig);
agentService.createAgent(agent, defaultUser);

View File

@@ -57,7 +57,7 @@ public class S2SingerDemo extends S2BaseDemo {
}
@Override
boolean checkNeedToRun() {
protected boolean checkNeedToRun() {
List<DomainResp> domainList = domainService.getDomainList();
for (DomainResp domainResp : domainList) {
if (domainResp.getBizName().equalsIgnoreCase("singer")) {

View File

@@ -46,9 +46,10 @@ public class S2SmallTalkDemo extends S2BaseDemo {
}
@Override
boolean checkNeedToRun() {
protected boolean checkNeedToRun() {
List<String> agentNames =
agentService.getAgents().stream().map(Agent::getName).collect(Collectors.toList());
return !agentNames.contains("闲聊");
return !agentNames.contains("闲聊助手");
}
}

View File

@@ -79,8 +79,8 @@ public class S2VisitsDemo extends S2BaseDemo {
ModelResp userModel = addModel_1(s2Domain, demoDatabase);
ModelResp pvUvModel = addModel_2(s2Domain, demoDatabase);
ModelResp stayTimeModel = addModel_3(s2Domain, demoDatabase);
addModelRela(s2Domain, userModel, pvUvModel, "user_name");
addModelRela(s2Domain, userModel, stayTimeModel, "user_name");
addModelRela(s2Domain, pvUvModel, userModel, "user_name");
addModelRela(s2Domain, stayTimeModel, userModel, "user_name");
// create metrics and dimensions
DimensionResp departmentDimension = getDimension("department", userModel);
@@ -146,7 +146,8 @@ public class S2VisitsDemo extends S2BaseDemo {
agent.setStatus(1);
agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList("近15天超音数访问次数汇总", "按部门统计超音数的访问人数", "对比alice和lucy的停留时长",
"过去30天访问次数最高的部门top3", "近1个月总访问次数超过100次的部门有几个", "过去半个月每个核心用户的总停留时长"));
"过去30天访问次数最高的部门top3", "近1个月总访问次数超过100次的部门有几个", "过去半个月每个核心用户的总停留时长",
"今年以来访问次数最高的一天是哪一天"));
// configure tools
ToolConfig toolConfig = new ToolConfig();
@@ -198,6 +199,7 @@ public class S2VisitsDemo extends S2BaseDemo {
List<Dim> dimensions = new ArrayList<>();
dimensions.add(new Dim("部门", "department", DimensionType.categorical, 1));
// dimensions.add(new Dim("用户", "user_name", DimensionType.categorical, 1));
modelDetail.setDimensions(dimensions);
List<Field> fields = Lists.newArrayList();
fields.add(Field.builder().fieldName("user_name").dataType("Varchar").build());
@@ -382,9 +384,9 @@ public class S2VisitsDemo extends S2BaseDemo {
metricReq.setDescription("访问的用户个数");
metricReq.setAlias("UV,访问人数");
MetricDefineByFieldParams metricTypeParams = new MetricDefineByFieldParams();
metricTypeParams.setExpr("count(distinct user_id)");
metricTypeParams.setExpr("count(distinct user_name)");
List<FieldParam> fieldParams = new ArrayList<>();
fieldParams.add(new FieldParam("user_id"));
fieldParams.add(new FieldParam("user_name"));
metricTypeParams.setFields(fieldParams);
metricReq.setMetricDefineByFieldParams(metricTypeParams);
metricReq.setMetricDefineType(MetricDefineType.FIELD);

View File

@@ -26,15 +26,18 @@ com.tencent.supersonic.headless.chat.parser.llm.DataSetResolver=\
com.tencent.supersonic.headless.core.translator.converter.QueryConverter=\
com.tencent.supersonic.headless.core.translator.converter.DefaultDimValueConverter,\
com.tencent.supersonic.headless.core.translator.converter.SqlVariableParseConverter,\
com.tencent.supersonic.headless.core.translator.converter.CalculateAggConverter,\
com.tencent.supersonic.headless.core.translator.converter.ParserDefaultConverter
com.tencent.supersonic.headless.core.translator.converter.SqlVariableConverter,\
com.tencent.supersonic.headless.core.translator.converter.MetricRatioConverter,\
com.tencent.supersonic.headless.core.translator.converter.SqlQueryConverter,\
com.tencent.supersonic.headless.core.translator.converter.StructQueryConverter
com.tencent.supersonic.headless.core.translator.QueryOptimizer=\
com.tencent.supersonic.headless.core.translator.DetailQueryOptimizer
com.tencent.supersonic.headless.core.translator.optimizer.QueryOptimizer=\
com.tencent.supersonic.headless.core.translator.optimizer.DetailQueryOptimizer,\
com.tencent.supersonic.headless.core.translator.optimizer.DbDialectOptimizer,\
com.tencent.supersonic.headless.core.translator.optimizer.ResultLimitOptimizer
com.tencent.supersonic.headless.core.translator.QueryParser=\
com.tencent.supersonic.headless.core.translator.calcite.CalciteQueryParser
com.tencent.supersonic.headless.core.translator.parser.QueryParser=\
com.tencent.supersonic.headless.core.translator.parser.calcite.CalciteQueryParser
com.tencent.supersonic.headless.core.executor.QueryExecutor=\
com.tencent.supersonic.headless.core.executor.JdbcExecutor

View File

@@ -11,4 +11,24 @@ spring:
h2:
console:
path: /h2-console/semantic
enabled: true
enabled: true
### Comment out following lines if using MySQL
#spring:
# datasource:
# driver-class-name: com.mysql.cj.jdbc.Driver
# url: jdbc:mysql://localhost:3306/s2_database?user=root
# username: root
# password:
# sql:
# enabled: true
# mode: always
# username: root
# password:
# init:
# schema-locations: classpath:db/schema-mysql.sql,classpath:db/schema-mysql-demo.sql
# data-locations: classpath:db/data-mysql.sql,classpath:db/data-mysql-demo.sql
# h2:
# console:
# path: /h2-console/semantic
# enabled: true

View File

@@ -396,5 +396,5 @@ ALTER TABLE s2_agent DROP COLUMN `enable_memory_review`;
alter table s2_agent add column `enable_feedback` tinyint DEFAULT 1;
--20241116
alter table s2_agent add column `admin` varchar(1000);
alter table s2_agent add column `viewer` varchar(1000);
alter table s2_agent add column `admin` varchar(1000) COLLATE utf8_unicode_ci DEFAULT NULL;
alter table s2_agent add column `viewer` varchar(1000) COLLATE utf8_unicode_ci DEFAULT NULL;

View File

@@ -1,4 +1,4 @@
-------S2VisitsDemo
-- S2VisitsDemo
insert into s2_user_department (user_name, department) values ('jack','HR');
insert into s2_user_department (user_name, department) values ('tom','sales');
insert into s2_user_department (user_name, department) values ('lucy','marketing');
@@ -1019,7 +1019,7 @@ INSERT INTO s2_stay_time_statis (imp_date, user_name, stay_hours, page) VALUES (
INSERT INTO s2_stay_time_statis (imp_date, user_name, stay_hours, page) VALUES (DATE_SUB(CURRENT_DATE(), INTERVAL 15 DAY), 'lucy', '0.8124302447925607', 'p4');
INSERT INTO s2_stay_time_statis (imp_date, user_name, stay_hours, page) VALUES (DATE_SUB(CURRENT_DATE(), INTERVAL 8 DAY), 'lucy', '0.039935860913407284', 'p2');
-------S2ArtistDemo
-- S2ArtistDemo
INSERT INTO singer (singer_name, act_area, song_name, genre, js_play_cnt, down_cnt, favor_cnt)
VALUES ('周杰伦', '港台', '青花瓷', '国风', 1000000, 1000000, 1000000);

View File

@@ -1,4 +1,4 @@
-------S2VisitsDemo
-- S2VisitsDemo
CREATE TABLE IF NOT EXISTS `s2_user_department` (
`user_name` varchar(200) NOT NULL,
`department` varchar(200) NOT NULL
@@ -27,7 +27,7 @@ CREATE TABLE IF NOT EXISTS `singer` (
`favor_cnt` bigint DEFAULT NULL
)ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
-------S2ArtistDemo
-- S2ArtistDemo
CREATE TABLE IF NOT EXISTS `genre` (
`g_name` varchar(20) NOT NULL , -- genre name
`rating` INT ,

View File

@@ -15,6 +15,8 @@ CREATE TABLE IF NOT EXISTS `s2_agent` (
`created_at` datetime DEFAULT NULL,
`updated_by` varchar(100) COLLATE utf8_unicode_ci DEFAULT NULL,
`updated_at` datetime DEFAULT NULL,
`admin` varchar(1000) COLLATE utf8_unicode_ci DEFAULT NULL,
`viewer` varchar(1000) COLLATE utf8_unicode_ci DEFAULT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci;
@@ -539,7 +541,7 @@ CREATE TABLE IF NOT EXISTS `s2_term` (
PRIMARY KEY (`id`)
) ENGINE = InnoDB DEFAULT CHARSET = utf8 COMMENT ='术语表';
CREATE TABLE `s2_user_token` (
CREATE TABLE IF NOT EXISTS `s2_user_token` (
`id` bigint NOT NULL AUTO_INCREMENT,
`name` VARCHAR(255) NOT NULL,
`user_name` VARCHAR(255) NOT NULL,

View File

@@ -12,14 +12,15 @@ import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
import com.tencent.supersonic.common.service.ChatModelService;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.response.ParseResp;
import com.tencent.supersonic.headless.api.pojo.response.QueryState;
import com.tencent.supersonic.headless.server.service.SchemaService;
import com.tencent.supersonic.util.DataUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import java.time.LocalDate;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@@ -38,6 +39,8 @@ public class BaseTest extends BaseApplication {
protected AgentService agentService;
@Autowired
protected ChatModelService chatModelService;
@Autowired
protected SchemaService schemaService;
@Value("${s2.demo.enableLLM:false}")
protected boolean enableLLM;
@@ -107,4 +110,10 @@ public class BaseTest extends BaseApplication {
assertEquals(expectedParseInfo.getDateInfo(), actualParseInfo.getDateInfo());
}
protected SchemaElement getSchemaElementByName(Set<SchemaElement> elementSet, String name) {
Optional<SchemaElement> matchElement =
elementSet.stream().filter(e -> e.getName().equals(name)).findFirst();
return matchElement.orElse(null);
}
}

View File

@@ -5,6 +5,7 @@ import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
@@ -12,6 +13,7 @@ import com.tencent.supersonic.headless.chat.query.rule.detail.DetailDimensionQue
import com.tencent.supersonic.util.DataUtils;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.junitpioneer.jupiter.SetSystemProperty;
import org.springframework.boot.test.context.SpringBootTest;
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
@@ -19,8 +21,9 @@ import org.springframework.boot.test.context.SpringBootTest;
public class DetailTest extends BaseTest {
@Test
@SetSystemProperty(key = "s2.test", value = "true")
public void test_detail_dimension() throws Exception {
QueryResult actualResult = submitNewChat("周杰伦流派和代表作", DataUtils.tagAgentId);
QueryResult actualResult = submitNewChat("周杰伦流派和代表作", DataUtils.singerAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
@@ -30,8 +33,11 @@ public class DetailTest extends BaseTest {
expectedParseInfo.setQueryType(QueryType.DETAIL);
expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
QueryFilter dimensionFilter =
DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS, "周杰伦", "歌手名", 8L);
DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.singerDatasettId);
SchemaElement singerElement = getSchemaElementByName(schema.getDimensions(), "歌手名");
QueryFilter dimensionFilter = DataUtils.getFilter("singer_name", FilterOperatorEnum.EQUALS,
"周杰伦", "歌手名", singerElement.getId());
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
expectedParseInfo.getDimensions()
@@ -43,7 +49,7 @@ public class DetailTest extends BaseTest {
@Test
public void test_detail_filter() throws Exception {
QueryResult actualResult = submitNewChat("国风歌手", DataUtils.tagAgentId);
QueryResult actualResult = submitNewChat("国风歌手", DataUtils.singerAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
@@ -53,8 +59,10 @@ public class DetailTest extends BaseTest {
expectedParseInfo.setQueryType(QueryType.DETAIL);
expectedParseInfo.setAggType(AggregateTypeEnum.NONE);
QueryFilter dimensionFilter =
DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS, "国风", "流派", 7L);
DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.singerDatasettId);
SchemaElement genreElement = getSchemaElementByName(schema.getDimensions(), "流派");
QueryFilter dimensionFilter = DataUtils.getFilter("genre", FilterOperatorEnum.EQUALS, "国风",
"流派", genreElement.getId());
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
expectedParseInfo.getDimensions()
.addAll(Lists.newArrayList(SchemaElement.builder().name("歌手名").build()));

View File

@@ -5,14 +5,18 @@ import com.tencent.supersonic.common.pojo.DateConf;
import com.tencent.supersonic.common.pojo.enums.DatePeriodEnum;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.QueryType;
import com.tencent.supersonic.headless.api.pojo.DataSetSchema;
import com.tencent.supersonic.headless.api.pojo.SchemaElement;
import com.tencent.supersonic.headless.api.pojo.SemanticParseInfo;
import com.tencent.supersonic.headless.api.pojo.request.QueryFilter;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricFilterQuery;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricGroupByQuery;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricModelQuery;
import com.tencent.supersonic.headless.chat.query.rule.metric.MetricTopNQuery;
import com.tencent.supersonic.util.DataUtils;
import org.junit.jupiter.api.Order;
import org.junit.jupiter.api.Test;
import org.junitpioneer.jupiter.SetSystemProperty;
import org.springframework.boot.test.context.SpringBootTest;
import java.text.DateFormat;
@@ -28,24 +32,16 @@ import static com.tencent.supersonic.common.pojo.enums.AggregateTypeEnum.SUM;
public class MetricTest extends BaseTest {
@Test
public void testMetric() throws Exception {
QueryResult actualResult = submitNewChat("超音数 访问次数", DataUtils.metricAgentId);
}
@Test
public void testMetricFilter() throws Exception {
QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.metricAgentId);
public void testMetricModel() throws Exception {
QueryResult actualResult = submitNewChat("超音数 访问次数", DataUtils.productAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
expectedResult.setChatContext(expectedParseInfo);
expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE);
expectedResult.setQueryMode(MetricModelQuery.QUERY_MODE);
expectedParseInfo.setAggType(NONE);
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getDimensionFilters().add(
DataUtils.getFilter("user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L));
expectedParseInfo.setDateInfo(
DataUtils.getDateConf(DateConf.DateMode.BETWEEN, unit, period, startDay, endDay));
@@ -56,8 +52,35 @@ public class MetricTest extends BaseTest {
}
@Test
public void testMetricFilter() throws Exception {
QueryResult actualResult = submitNewChat("alice的访问次数", DataUtils.productAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
expectedResult.setChatContext(expectedParseInfo);
expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE);
expectedParseInfo.setAggType(NONE);
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.productDatasetId);
SchemaElement userElement = getSchemaElementByName(schema.getDimensions(), "用户");
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
FilterOperatorEnum.EQUALS, "alice", "用户", userElement.getId()));
expectedParseInfo.setDateInfo(
DataUtils.getDateConf(DateConf.DateMode.BETWEEN, unit, period, startDay, endDay));
expectedParseInfo.setQueryType(QueryType.AGGREGATE);
assertQueryResult(expectedResult, actualResult);
assert actualResult.getQueryResults().size() == 1;
}
@Test
@SetSystemProperty(key = "s2.test", value = "true")
public void testMetricGroupBy() throws Exception {
QueryResult actualResult = submitNewChat("近7天超音数各部门的访问次数", DataUtils.metricAgentId);
QueryResult actualResult = submitNewChat("近7天超音数各部门的访问次数和停留时长", DataUtils.productAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
@@ -67,6 +90,7 @@ public class MetricTest extends BaseTest {
expectedParseInfo.setAggType(NONE);
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("停留时长"));
expectedParseInfo.getDimensions().add(DataUtils.getSchemaElement("部门"));
expectedParseInfo.setDateInfo(DataUtils.getDateConf(DateConf.DateMode.BETWEEN, 7,
@@ -79,7 +103,7 @@ public class MetricTest extends BaseTest {
@Test
public void testMetricFilterCompare() throws Exception {
QueryResult actualResult = submitNewChat("对比alice和lucy的访问次数", DataUtils.metricAgentId);
QueryResult actualResult = submitNewChat("对比alice和lucy的访问次数", DataUtils.productAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
@@ -92,8 +116,11 @@ public class MetricTest extends BaseTest {
List<String> list = new ArrayList<>();
list.add("alice");
list.add("lucy");
QueryFilter dimensionFilter =
DataUtils.getFilter("user_name", FilterOperatorEnum.IN, list, "用户", 2L);
DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.productDatasetId);
SchemaElement userElement = getSchemaElementByName(schema.getDimensions(), "用户");
QueryFilter dimensionFilter = DataUtils.getFilter("user_name", FilterOperatorEnum.IN, list,
"用户", userElement.getId());
expectedParseInfo.getDimensionFilters().add(dimensionFilter);
expectedParseInfo.setDateInfo(
@@ -107,7 +134,7 @@ public class MetricTest extends BaseTest {
@Test
@Order(3)
public void testMetricTopN() throws Exception {
QueryResult actualResult = submitNewChat("近3天访问次数最多的用户", DataUtils.metricAgentId);
QueryResult actualResult = submitNewChat("近3天访问次数最多的用户", DataUtils.productAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
@@ -128,7 +155,7 @@ public class MetricTest extends BaseTest {
@Test
public void testMetricGroupBySum() throws Exception {
QueryResult actualResult = submitNewChat("近7天超音数各部门的访问次数总和", DataUtils.metricAgentId);
QueryResult actualResult = submitNewChat("近7天超音数各部门的访问次数总和", DataUtils.productAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
expectedResult.setChatContext(expectedParseInfo);
@@ -154,7 +181,7 @@ public class MetricTest extends BaseTest {
String dateStr = textFormat.format(format.parse(startDay));
QueryResult actualResult =
submitNewChat(String.format("alice在%s的访问次数", dateStr), DataUtils.metricAgentId);
submitNewChat(String.format("alice在%s的访问次数", dateStr), DataUtils.productAgentId);
QueryResult expectedResult = new QueryResult();
SemanticParseInfo expectedParseInfo = new SemanticParseInfo();
@@ -163,9 +190,11 @@ public class MetricTest extends BaseTest {
expectedResult.setQueryMode(MetricFilterQuery.QUERY_MODE);
expectedParseInfo.setAggType(NONE);
DataSetSchema schema = schemaService.getDataSetSchema(DataUtils.productDatasetId);
SchemaElement userElement = getSchemaElementByName(schema.getDimensions(), "用户");
expectedParseInfo.getMetrics().add(DataUtils.getSchemaElement("访问次数"));
expectedParseInfo.getDimensionFilters().add(
DataUtils.getFilter("user_name", FilterOperatorEnum.EQUALS, "alice", "用户", 2L));
expectedParseInfo.getDimensionFilters().add(DataUtils.getFilter("user_name",
FilterOperatorEnum.EQUALS, "alice", "用户", userElement.getId()));
expectedParseInfo.setDateInfo(
DataUtils.getDateConf(DateConf.DateMode.BETWEEN, 1, period, startDay, startDay));

View File

@@ -5,7 +5,10 @@ import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.tencent.supersonic.chat.BaseTest;
import com.tencent.supersonic.chat.api.pojo.response.QueryResult;
import com.tencent.supersonic.chat.server.agent.*;
import com.tencent.supersonic.chat.server.agent.Agent;
import com.tencent.supersonic.chat.server.agent.AgentToolType;
import com.tencent.supersonic.chat.server.agent.DatasetTool;
import com.tencent.supersonic.chat.server.agent.ToolConfig;
import com.tencent.supersonic.common.config.ChatModel;
import com.tencent.supersonic.common.pojo.ChatApp;
import com.tencent.supersonic.common.pojo.User;
@@ -133,11 +136,28 @@ public class Text2SQLEval extends BaseTest {
assert result.getTextResult().contains("3");
}
@Test
public void test_detail_query() throws Exception {
long start = System.currentTimeMillis();
QueryResult result = submitNewChat("特斯拉旗下有哪些品牌", agentId);
durations.add(System.currentTimeMillis() - start);
assert result.getQueryColumns().size() >= 1;
assert result.getTextResult().contains("Model Y");
assert result.getTextResult().contains("Model 3");
}
public Agent getLLMAgent() {
Agent agent = new Agent();
agent.setName("Agent for Test");
ToolConfig toolConfig = new ToolConfig();
toolConfig.getTools().add(getDatasetTool());
DatasetTool datasetTool = new DatasetTool();
datasetTool.setType(AgentToolType.DATASET);
datasetTool.setDataSetIds(Lists.newArrayList(DataUtils.productDatasetId));
toolConfig.getTools().add(datasetTool);
DatasetTool datasetTool2 = new DatasetTool();
datasetTool2.setType(AgentToolType.DATASET);
datasetTool2.setDataSetIds(Lists.newArrayList(DataUtils.companyDatasetId));
toolConfig.getTools().add(datasetTool2);
agent.setToolConfig(JSONObject.toJSONString(toolConfig));
// create chat model for this evaluation
ChatModel chatModel = new ChatModel();
@@ -154,11 +174,4 @@ public class Text2SQLEval extends BaseTest {
return agent;
}
private static DatasetTool getDatasetTool() {
DatasetTool datasetTool = new DatasetTool();
datasetTool.setType(AgentToolType.DATASET);
datasetTool.setDataSetIds(Lists.newArrayList(1L));
return datasetTool;
}
}

View File

@@ -1,14 +1,18 @@
package com.tencent.supersonic.headless;
import com.tencent.supersonic.common.pojo.Filter;
import com.tencent.supersonic.common.pojo.User;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.headless.api.pojo.request.QueryMetricReq;
import com.tencent.supersonic.headless.api.pojo.request.QueryStructReq;
import com.tencent.supersonic.headless.api.pojo.response.SemanticQueryResp;
import com.tencent.supersonic.headless.server.service.MetricService;
import org.junit.Assert;
import org.junit.jupiter.api.Test;
import org.junitpioneer.jupiter.SetSystemProperty;
import org.springframework.beans.factory.annotation.Autowired;
import java.time.LocalDate;
import java.util.Arrays;
import static org.junit.Assert.assertThrows;
@@ -23,16 +27,24 @@ public class QueryByMetricTest extends BaseTest {
QueryMetricReq queryMetricReq = new QueryMetricReq();
queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv"));
queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department"));
queryMetricReq.getFilters().add(Filter.builder().name("imp_date")
.operator(FilterOperatorEnum.MINOR_THAN_EQUALS).relation(Filter.Relation.FILTER)
.value(LocalDate.now().toString()).build());
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
Assert.assertNotNull(queryResp.getResultList());
Assert.assertEquals(6, queryResp.getResultList().size());
}
@Test
@SetSystemProperty(key = "s2.test", value = "true")
public void testWithMetricAndDimensionNames() throws Exception {
QueryMetricReq queryMetricReq = new QueryMetricReq();
queryMetricReq.setMetricNames(Arrays.asList("停留时长", "访问次数"));
queryMetricReq.setDimensionNames(Arrays.asList("用户", "部门"));
queryMetricReq.getFilters()
.add(Filter.builder().name("数据日期").operator(FilterOperatorEnum.MINOR_THAN_EQUALS)
.relation(Filter.Relation.FILTER).value(LocalDate.now().toString())
.build());
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
Assert.assertNotNull(queryResp.getResultList());
Assert.assertEquals(6, queryResp.getResultList().size());
@@ -44,6 +56,9 @@ public class QueryByMetricTest extends BaseTest {
queryMetricReq.setDomainId(1L);
queryMetricReq.setMetricNames(Arrays.asList("stay_hours", "pv"));
queryMetricReq.setDimensionNames(Arrays.asList("user_name", "department"));
queryMetricReq.getFilters().add(Filter.builder().name("imp_date")
.operator(FilterOperatorEnum.MINOR_THAN_EQUALS).relation(Filter.Relation.FILTER)
.value(LocalDate.now().toString()).build());
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
Assert.assertNotNull(queryResp.getResultList());
Assert.assertEquals(6, queryResp.getResultList().size());
@@ -61,6 +76,9 @@ public class QueryByMetricTest extends BaseTest {
queryMetricReq.setDomainId(1L);
queryMetricReq.setMetricIds(Arrays.asList(1L, 3L));
queryMetricReq.setDimensionIds(Arrays.asList(1L, 2L));
queryMetricReq.getFilters().add(Filter.builder().name("imp_date")
.operator(FilterOperatorEnum.MINOR_THAN_EQUALS).relation(Filter.Relation.FILTER)
.value(LocalDate.now().toString()).build());
SemanticQueryResp queryResp = queryByMetric(queryMetricReq, User.getDefaultUser());
Assert.assertNotNull(queryResp.getResultList());
Assert.assertEquals(6, queryResp.getResultList().size());

View File

@@ -18,7 +18,7 @@ public class TranslateTest extends BaseTest {
public void testSqlExplain() throws Exception {
String sql = "SELECT 部门, SUM(访问次数) AS 访问次数 FROM 超音数PVUV统计 GROUP BY 部门 ";
SemanticTranslateResp explain = semanticLayerService.translate(
QueryReqBuilder.buildS2SQLReq(sql, DataUtils.getMetricAgentView()),
QueryReqBuilder.buildS2SQLReq(sql, DataUtils.productDatasetId),
User.getDefaultUser());
assertNotNull(explain);
assertNotNull(explain.getQuerySQL());

View File

@@ -15,10 +15,15 @@ import static java.time.LocalDate.now;
public class DataUtils {
public static final Integer metricAgentId = 1;
public static final Integer tagAgentId = 2;
public static final Integer productAgentId = 1;
public static final Integer companyAgentId = 2;
public static final Integer singerAgentId = 3;
public static final Long productDatasetId = 1L;
public static final Long companyDatasetId = 2L;
public static final Long singerDatasettId = 3L;
public static final Integer ONE_TURNS_CHAT_ID = 10;
public static final Integer MULTI_TURNS_CHAT_ID = 11;
private static final User user_test = User.getDefaultUser();
public static User getUser() {
@@ -40,7 +45,7 @@ public class DataUtils {
public static ChatParseReq getChatParseReq(Integer id, String query, boolean enableLLM) {
ChatParseReq chatParseReq = new ChatParseReq();
chatParseReq.setQueryText(query);
chatParseReq.setAgentId(metricAgentId);
chatParseReq.setAgentId(productAgentId);
chatParseReq.setChatId(id);
chatParseReq.setUser(user_test);
chatParseReq.setDisableLLM(!enableLLM);
@@ -92,7 +97,4 @@ public class DataUtils {
return result;
}
public static Long getMetricAgentView() {
return 1L;
}
}

View File

@@ -11,7 +11,12 @@ public class LLMConfigUtils {
OPENAI_GLM(false),
OLLAMA_LLAMA3(true),
OLLAMA_QWEN2(true),
OLLAMA_QWEN25(true);
OLLAMA_QWEN25_7B(true),
OLLAMA_QWEN25_14B(true),
OLLAMA_QWEN25_CODE_7B(true),
OLLAMA_QWEN25_CODE_3B(true),
OLLAMA_GLM4(true);
public boolean isOllam;
@@ -35,10 +40,26 @@ public class LLMConfigUtils {
baseUrl = "http://localhost:11434";
modelName = "qwen2:7b";
break;
case OLLAMA_QWEN25:
case OLLAMA_QWEN25_7B:
baseUrl = "http://localhost:11434";
modelName = "qwen2.5:7b";
break;
case OLLAMA_QWEN25_14B:
baseUrl = "http://localhost:11434";
modelName = "qwen2.5:14b";
break;
case OLLAMA_QWEN25_CODE_7B:
baseUrl = "http://localhost:11434";
modelName = "qwen2.5-coder:7b";
break;
case OLLAMA_QWEN25_CODE_3B:
baseUrl = "http://localhost:11434";
modelName = "qwen2.5-coder:3b";
break;
case OLLAMA_GLM4:
baseUrl = "http://localhost:11434";
modelName = "glm4:latest";
break;
case OPENAI_GLM:
baseUrl = "https://open.bigmodel.cn/api/pas/v4/";
apiKey = "REPLACE_WITH_YOUR_KEY";

View File

@@ -1,14 +1,34 @@
spring:
datasource:
driver-class-name: org.h2.Driver
url: jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false;QUERY_TIMEOUT=100
url: jdbc:h2:mem:semantic;DATABASE_TO_UPPER=false;QUERY_TIMEOUT=30
username: root
password: semantic
sql:
init:
schema-locations: classpath:db/schema-h2.sql
data-locations: classpath:db/data-h2.sql
schema-locations: classpath:db/schema-h2.sql,classpath:db/schema-h2-demo.sql
data-locations: classpath:db/data-h2.sql,classpath:db/data-h2-demo.sql
h2:
console:
path: /h2-console/semantic
enabled: true
enabled: true
### Comment out following lines if using MySQL
#spring:
# datasource:
# driver-class-name: com.mysql.cj.jdbc.Driver
# url: jdbc:mysql://localhost:3306/s2_database?user=root
# username: root
# password:
# sql:
# enabled: true
# mode: always
# username: root
# password:
# init:
# schema-locations: classpath:db/schema-mysql.sql,classpath:db/schema-mysql-demo.sql
# data-locations: classpath:db/data-mysql.sql,classpath:db/data-mysql-demo.sql
# h2:
# console:
# path: /h2-console/semantic
# enabled: true

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,85 @@
-------S2VisitsDemo
CREATE TABLE IF NOT EXISTS `s2_user_department` (
`user_name` varchar(200) NOT NULL,
`department` varchar(200) NOT NULL, -- department of user
PRIMARY KEY (`user_name`,`department`)
);
COMMENT ON TABLE s2_user_department IS 'user_department_info';
CREATE TABLE IF NOT EXISTS `s2_pv_uv_statis` (
`imp_date` varchar(200) NOT NULL,
`user_name` varchar(200) NOT NULL,
`page` varchar(200) NOT NULL
);
COMMENT ON TABLE s2_pv_uv_statis IS 's2_pv_uv_statis';
CREATE TABLE IF NOT EXISTS `s2_stay_time_statis` (
`imp_date` varchar(200) NOT NULL,
`user_name` varchar(200) NOT NULL,
`stay_hours` DOUBLE NOT NULL,
`page` varchar(200) NOT NULL
);
COMMENT ON TABLE s2_stay_time_statis IS 's2_stay_time_statis_info';
-------S2ArtistDemo
CREATE TABLE IF NOT EXISTS `singer` (
`singer_name` varchar(200) NOT NULL,
`act_area` varchar(200) NOT NULL,
`song_name` varchar(200) NOT NULL,
`genre` varchar(200) NOT NULL,
`js_play_cnt` bigINT DEFAULT NULL,
`down_cnt` bigINT DEFAULT NULL,
`favor_cnt` bigINT DEFAULT NULL,
PRIMARY KEY (`singer_name`)
);
COMMENT ON TABLE singer IS 'singer_info';
CREATE TABLE IF NOT EXISTS `genre` (
`g_name` varchar(20) NOT NULL , -- genre name
`rating` INT ,
`most_popular_in` varchar(50) ,
PRIMARY KEY (`g_name`)
);
COMMENT ON TABLE genre IS 'genre';
CREATE TABLE IF NOT EXISTS `artist` (
`artist_name` varchar(50) NOT NULL , -- genre name
`citizenship` varchar(20) ,
`gender` varchar(20) ,
`g_name` varchar(50),
PRIMARY KEY (`artist_name`,`citizenship`)
);
COMMENT ON TABLE artist IS 'artist';
-------S2CompanyDemo
CREATE TABLE IF NOT EXISTS `company` (
`company_id` varchar(50) NOT NULL ,
`company_name` varchar(50) NOT NULL ,
`headquarter_address` varchar(50) NOT NULL ,
`company_established_time` varchar(20) NOT NULL ,
`founder` varchar(20) NOT NULL ,
`ceo` varchar(20) NOT NULL ,
`annual_turnover` bigint(15) ,
`employee_count` int(7) ,
PRIMARY KEY (`company_id`)
);
CREATE TABLE IF NOT EXISTS `brand` (
`brand_id` varchar(50) NOT NULL ,
`brand_name` varchar(50) NOT NULL ,
`brand_established_time` varchar(20) NOT NULL ,
`company_id` varchar(50) NOT NULL ,
`legal_representative` varchar(20) NOT NULL ,
`registered_capital` bigint(15) ,
PRIMARY KEY (`brand_id`)
);
CREATE TABLE IF NOT EXISTS `brand_revenue` (
`year_time` varchar(10) NOT NULL ,
`brand_id` varchar(50) NOT NULL ,
`revenue` bigint(15) NOT NULL,
`profit` bigint(15) NOT NULL ,
`revenue_growth_year_on_year` double NOT NULL ,
`profit_growth_year_on_year` double NOT NULL
);

View File

@@ -21,7 +21,7 @@ s2:
date: true
demo:
names: S2VisitsDemo,S2SingerDemo
names: S2VisitsDemo,S2SingerDemo,S2CompanyDemo
enableLLM: false
authentication: