mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-11 03:58:14 +00:00
(improvement)(chat) fix embedding null pointer (#712)
This commit is contained in:
@@ -11,6 +11,8 @@ import com.tencent.supersonic.chat.core.utils.HanlpHelper;
|
|||||||
import com.tencent.supersonic.common.util.ContextUtils;
|
import com.tencent.supersonic.common.util.ContextUtils;
|
||||||
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
import com.tencent.supersonic.common.util.embedding.Retrieval;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Objects;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.apache.commons.lang3.StringUtils;
|
import org.apache.commons.lang3.StringUtils;
|
||||||
|
|
||||||
@@ -37,7 +39,9 @@ public class EmbeddingMapper extends BaseMapper {
|
|||||||
|
|
||||||
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()),
|
SchemaElement schemaElement = JSONObject.parseObject(JSONObject.toJSONString(matchResult.getMetadata()),
|
||||||
SchemaElement.class);
|
SchemaElement.class);
|
||||||
|
if (Objects.isNull(matchResult.getMetadata())) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
String modelIdStr = matchResult.getMetadata().get("modelId");
|
String modelIdStr = matchResult.getMetadata().get("modelId");
|
||||||
if (StringUtils.isBlank(modelIdStr)) {
|
if (StringUtils.isBlank(modelIdStr)) {
|
||||||
continue;
|
continue;
|
||||||
|
|||||||
@@ -106,7 +106,9 @@ public class EmbeddingMatchStrategy extends BaseMatchStrategy<EmbeddingResult> {
|
|||||||
if (StringUtils.isBlank(modelIdStr)) {
|
if (StringUtils.isBlank(modelIdStr)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return detectModelIds.contains(Long.parseLong(modelIdStr));
|
//return detectModelIds.contains(Long.parseLong(modelIdStr));
|
||||||
|
Double modelId = Double.parseDouble(modelIdStr);
|
||||||
|
return detectModelIds.contains(modelId.longValue());
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
# 评测流程
|
# 评测流程
|
||||||
|
|
||||||
1. 正常启动项目(必须包括LLM服务)
|
1. 正常启动项目(必须包括LLM服务)
|
||||||
2. 将要评测问题放到evalution/data目录下,如:internet.txt;将要评测问题对应的SQL也放到evalution/data目录下,如:gold_example_dusql.txt。
|
2. 执行evalution.sh脚本,主要包括构建表数据、获取模型预测结果,执行对比逻辑。可以在命令行看到执行准确率,错误case会写到同目录的error_case.json文件中。
|
||||||
3. 执行evalution.sh脚本,主要包括构建表数据、获取模型预测结果,执行对比逻辑。可以在命令行看到执行准确率,错误case会写到同目录的eval.json文件中。
|
|
||||||
|
|
||||||
# 评测意义
|
# 评测意义
|
||||||
|
|
||||||
制定大模型评估框架对于提示词或代码更改的影响至关重要,可以帮助我们了解这些变化是否会提高或降低准确率、响应速度。 随着产品规模的扩大,如果没有这样的框架,就会发现自己在盲目地调整黑匣子,有助于帮助我们减少问题、提高效率、增强模型能力。大模型评测的核心目的是确定模型的"聪明"程度,深入探讨其性能、特点和局限性,为行业应用提供方向。
|
制定评估工具对于提示词或代码更改的影响至关重要,方便supersonic快速对接其他模型、更改配置,可以帮助我们了解这些变化是否会提高或降低准确率、响应速度。
|
||||||
通过评测,我们可以更好地了解模型的性能、特点、价值、局限性和潜在风险,并为其发展和应用提供支持,具有重要意义。
|
|
||||||
|
|||||||
@@ -48,7 +48,7 @@ def get_pred_result():
|
|||||||
config_file=current_directory+"/config/config.yaml"
|
config_file=current_directory+"/config/config.yaml"
|
||||||
with open(config_file, 'r') as file:
|
with open(config_file, 'r') as file:
|
||||||
config = yaml.safe_load(file)
|
config = yaml.safe_load(file)
|
||||||
input_path=current_directory+"/data/"+config["domain"]+".txt"
|
input_path=current_directory+"/data/"+"internet.txt"
|
||||||
pred_sql_path = current_directory+"/data/"+"pred_example_dusql.txt"
|
pred_sql_path = current_directory+"/data/"+"pred_example_dusql.txt"
|
||||||
pred_sql_exist=os.path.exists(pred_sql_path)
|
pred_sql_exist=os.path.exists(pred_sql_path)
|
||||||
if pred_sql_exist:
|
if pred_sql_exist:
|
||||||
|
|||||||
@@ -288,7 +288,7 @@ def build_table():
|
|||||||
with open(config_file, 'r') as file:
|
with open(config_file, 'r') as file:
|
||||||
config = yaml.safe_load(file)
|
config = yaml.safe_load(file)
|
||||||
db_path=current_directory+"/data/"
|
db_path=current_directory+"/data/"
|
||||||
db_file=db_path+config["domain"]+".db"
|
db_file=db_path+"internet.db"
|
||||||
db_exist=os.path.exists(db_file)
|
db_exist=os.path.exists(db_file)
|
||||||
if db_exist:
|
if db_exist:
|
||||||
os.remove(db_file)
|
os.remove(db_file)
|
||||||
@@ -301,7 +301,7 @@ if __name__ == '__main__':
|
|||||||
with open(config_file, 'r') as file:
|
with open(config_file, 'r') as file:
|
||||||
config = yaml.safe_load(file)
|
config = yaml.safe_load(file)
|
||||||
db_path=current_directory+"/data/"
|
db_path=current_directory+"/data/"
|
||||||
db_file=db_path+config["domain"]+".db"
|
db_file=db_path+"internet.db"
|
||||||
db_exist=os.path.exists(db_file)
|
db_exist=os.path.exists(db_file)
|
||||||
if db_exist:
|
if db_exist:
|
||||||
os.remove(db_file)
|
os.remove(db_file)
|
||||||
@@ -311,30 +311,3 @@ if __name__ == '__main__':
|
|||||||
#build_china_travel_agency(path)
|
#build_china_travel_agency(path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# sql="SELECT T3.company_name, T3.annual_turnover, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id"
|
|
||||||
# query_sql(sql)
|
|
||||||
# sql =" select * from ( SELECT brand_name, company_name, sum(revenue_proportion), sum(annual_turnover) FROM (select `sys_imp_date`, `company_name`, `brand_name`, `annual_turnover`, `revenue_proportion`from(select `sys_imp_date`, `company_name`, `brand_name`, `annual_turnover` as `annual_turnover`, `revenue_proportion` as `revenue_proportion`from(select `annual_turnover` as `annual_turnover`, `revenue_proportion` as `revenue_proportion`, `sys_imp_date`, `company_name`, `brand_name`from(select `src1_brand`.`sys_imp_date` as `sys_imp_date`, `src1_company`.`annual_turnover` as `annual_turnover`, `src1_company_revenue`.`revenue_proportion` as `revenue_proportion`, `src1_brand`.`company_id` as `company_id`, `src1_company`.`company_name` as `company_name`, `src1_brand`.`brand_name` as `brand_name`, `src1_brand`.`brand_id` as `brand_id`from(select `annual_turnover` as `annual_turnover`, `imp_date` as `sys_imp_date`, `company_id`, `company_name` as `company_name`, `imp_date` as `imp_date`from(select `imp_date`, `company_id`, `company_name`, `headquarter_address`, `company_established_time`, `founder`, `ceo`, `annual_turnover`, `employee_count`from`company`) as `company`) as `src1_company`inner join (select `revenue_proportion` as `revenue_proportion`, `imp_date` as `sys_imp_date`, `company_id`, `brand_id`, `imp_date` as `imp_date`from(select `imp_date`, `company_id`, `brand_id`, `revenue_proportion`, `profit_proportion`, `expenditure_proportion`from`company_revenue`) as `company_revenue`) as `src1_company_revenue` on `src1_company`.`company_id` = `src1_company_revenue`.`company_id`inner join (select `imp_date` as `sys_imp_date`, `company_id`, `brand_name` as `brand_name`, `brand_id`, `imp_date` as `imp_date`from(select `imp_date`, `brand_id`, `brand_name`, `brand_established_time`, `company_id`, `legal_representative`, `registered_capital`from`brand`) as `brand`) as `src1_brand` on `src1_company`.`company_id` = `src1_brand`.`company_id`) as `src11_`) as `company_company_revenue_brand_0`) as `company_company_revenue_brand_1`) t_103 WHERE sys_imp_date = '2024-01-11' GROUP BY brand_name, company_name, sys_imp_date ORDER BY sum(revenue_proportion) DESC ) a limit 1000 "
|
|
||||||
# a=query_sql(sql)
|
|
||||||
# print(a[0][0])
|
|
||||||
|
|
||||||
|
|
||||||
# import pymysql
|
|
||||||
#
|
|
||||||
# db = pymysql.connect(host="11.154.212.211", user="semantic", password="semantic2023", database="internet", port=3306)
|
|
||||||
#
|
|
||||||
# # 使用cursor()方法获取操作游标
|
|
||||||
# cursor = db.cursor()
|
|
||||||
#
|
|
||||||
# # 使用execute方法执行SQL语句
|
|
||||||
# cursor.execute("select * from company")
|
|
||||||
#
|
|
||||||
# # 使用 fetchone() 方法获取一条数据
|
|
||||||
# data = cursor.fetchone()
|
|
||||||
# print(data)
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
chat_id: 3
|
chat_id: 3
|
||||||
agent_id: 4
|
agent_id: 4
|
||||||
domain: internet
|
|
||||||
url: http://localhost:9080
|
url: http://localhost:9080
|
||||||
|
|||||||
17
evaluation/error_case.json
Normal file
17
evaluation/error_case.json
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
[
|
||||||
|
{
|
||||||
|
"query": "在各公司所有品牌收入排名中,给出每一个品牌,其所在公司以及收入占该公司的总收入比例,同时给出该公司的年营业额",
|
||||||
|
"gold_sql": "SELECT T3.company_name, T3.annual_turnover, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id",
|
||||||
|
"pred_sql": "select * from tablea"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"query": "在各公司所有品牌收入排名中,给出每一个品牌,其所在公司以及收入占该公司的总收入比例",
|
||||||
|
"gold_sql": "SELECT T3.company_name, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id",
|
||||||
|
"pred_sql": "select * from tablea"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"query": "在各公司所有品牌收入排名中,给出每一个品牌和其法人,其所在公司以及收入占该公司的总收入比例",
|
||||||
|
"gold_sql": "SELECT T3.company_name, T2.brand_name, T2.legal_representative, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id",
|
||||||
|
"pred_sql": "select * from tablea"
|
||||||
|
}
|
||||||
|
]
|
||||||
@@ -513,7 +513,6 @@ def evaluate(gold, predict, db_dir, etype, kmaps,query_path):
|
|||||||
db_name = db
|
db_name = db
|
||||||
# db = os.path.join(db_dir, db, db + ".sqlite")
|
# db = os.path.join(db_dir, db, db + ".sqlite")
|
||||||
db = os.path.join(db_dir,db + ".db")
|
db = os.path.join(db_dir,db + ".db")
|
||||||
print(db)
|
|
||||||
schema = Schema(get_schema(db))
|
schema = Schema(get_schema(db))
|
||||||
g_sql = get_sql(schema, g_str)
|
g_sql = get_sql(schema, g_str)
|
||||||
hardness = evaluator.eval_hardness(g_sql)
|
hardness = evaluator.eval_hardness(g_sql)
|
||||||
@@ -597,7 +596,7 @@ def evaluate(gold, predict, db_dir, etype, kmaps,query_path):
|
|||||||
print_scores(scores, etype)
|
print_scores(scores, etype)
|
||||||
print(scores['all']['exec'])
|
print(scores['all']['exec'])
|
||||||
current_directory = os.path.dirname(os.path.abspath(__file__))
|
current_directory = os.path.dirname(os.path.abspath(__file__))
|
||||||
file_name=current_directory+"/eval.json"
|
file_name=current_directory+"/error_case.json"
|
||||||
json_exist=os.path.exists(file_name)
|
json_exist=os.path.exists(file_name)
|
||||||
if json_exist:
|
if json_exist:
|
||||||
os.remove(file_name)
|
os.remove(file_name)
|
||||||
@@ -884,11 +883,11 @@ def get_evaluation_result():
|
|||||||
config = yaml.safe_load(file)
|
config = yaml.safe_load(file)
|
||||||
db_dir=current_directory+"/data"
|
db_dir=current_directory+"/data"
|
||||||
db_path=current_directory+"/data/"
|
db_path=current_directory+"/data/"
|
||||||
db_file=db_path+config["domain"]+".db"
|
db_file=db_path+"internet.db"
|
||||||
pred = current_directory+"/data/"+"pred_example_dusql.txt"
|
pred = current_directory+"/data/"+"pred_example_dusql.txt"
|
||||||
gold = current_directory+"/data/"+"gold_example_dusql.txt"
|
gold = current_directory+"/data/"+"gold_example_dusql.txt"
|
||||||
table= current_directory+"/data/"+"tables_dusql.json"
|
table= current_directory+"/data/"+"tables_dusql.json"
|
||||||
query_path=current_directory+"/data/"+config["domain"]+".txt"
|
query_path=current_directory+"/data/"+"internet.txt"
|
||||||
etype="exec"
|
etype="exec"
|
||||||
kmaps = build_foreign_key_map_from_json(table)
|
kmaps = build_foreign_key_map_from_json(table)
|
||||||
|
|
||||||
@@ -900,7 +899,7 @@ def remove_unused_file():
|
|||||||
with open(config_file, 'r') as file:
|
with open(config_file, 'r') as file:
|
||||||
config = yaml.safe_load(file)
|
config = yaml.safe_load(file)
|
||||||
db_path=current_directory+"/data/"
|
db_path=current_directory+"/data/"
|
||||||
db_file=db_path+config["domain"]+".db"
|
db_file=db_path+"internet.db"
|
||||||
pred_file = current_directory+"/data/"+"pred_example_dusql.txt"
|
pred_file = current_directory+"/data/"+"pred_example_dusql.txt"
|
||||||
|
|
||||||
db_exist=os.path.exists(db_file)
|
db_exist=os.path.exists(db_file)
|
||||||
|
|||||||
@@ -6,18 +6,28 @@ import com.tencent.supersonic.common.pojo.JoinCondition;
|
|||||||
import com.tencent.supersonic.common.pojo.ModelRela;
|
import com.tencent.supersonic.common.pojo.ModelRela;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
|
import com.tencent.supersonic.common.pojo.enums.TimeMode;
|
||||||
import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
|
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
|
||||||
import com.tencent.supersonic.headless.api.pojo.Dim;
|
import com.tencent.supersonic.headless.api.pojo.Dim;
|
||||||
import com.tencent.supersonic.headless.api.pojo.DimensionTimeTypeParams;
|
import com.tencent.supersonic.headless.api.pojo.DimensionTimeTypeParams;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.ViewDetail;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.ViewModelConfig;
|
||||||
import com.tencent.supersonic.headless.api.pojo.Identify;
|
import com.tencent.supersonic.headless.api.pojo.Identify;
|
||||||
import com.tencent.supersonic.headless.api.pojo.Measure;
|
import com.tencent.supersonic.headless.api.pojo.Measure;
|
||||||
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
|
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.TagTypeDefaultConfig;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.MetricTypeDefaultConfig;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.DomainReq;
|
import com.tencent.supersonic.headless.api.pojo.request.DomainReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
|
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.request.ViewReq;
|
||||||
import com.tencent.supersonic.headless.server.service.DomainService;
|
import com.tencent.supersonic.headless.server.service.DomainService;
|
||||||
import com.tencent.supersonic.headless.server.service.ModelRelaService;
|
import com.tencent.supersonic.headless.server.service.ModelRelaService;
|
||||||
import com.tencent.supersonic.headless.server.service.ModelService;
|
import com.tencent.supersonic.headless.server.service.ModelService;
|
||||||
|
import com.tencent.supersonic.headless.server.service.ViewService;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
@@ -40,6 +50,9 @@ public class BenchMarkDemoDataLoader {
|
|||||||
@Autowired
|
@Autowired
|
||||||
private ModelRelaService modelRelaService;
|
private ModelRelaService modelRelaService;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private ViewService viewService;
|
||||||
|
|
||||||
public void doRun() {
|
public void doRun() {
|
||||||
try {
|
try {
|
||||||
addDomain();
|
addDomain();
|
||||||
@@ -47,6 +60,7 @@ public class BenchMarkDemoDataLoader {
|
|||||||
addModel_2();
|
addModel_2();
|
||||||
addModel_3();
|
addModel_3();
|
||||||
addModel_4();
|
addModel_4();
|
||||||
|
addView_1();
|
||||||
addModelRela_1();
|
addModelRela_1();
|
||||||
addModelRela_2();
|
addModelRela_2();
|
||||||
addModelRela_3();
|
addModelRela_3();
|
||||||
@@ -194,6 +208,42 @@ public class BenchMarkDemoDataLoader {
|
|||||||
modelService.createModel(modelReq, user);
|
modelService.createModel(modelReq, user);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void addView_1() {
|
||||||
|
ViewReq viewReq = new ViewReq();
|
||||||
|
viewReq.setName("cspider");
|
||||||
|
viewReq.setBizName("singer");
|
||||||
|
viewReq.setDomainId(3L);
|
||||||
|
viewReq.setDescription("包含cspider数据集相关标签和指标信息");
|
||||||
|
viewReq.setAdmins(Lists.newArrayList("admin"));
|
||||||
|
List<ViewModelConfig> viewModelConfigs = Lists.newArrayList(
|
||||||
|
new ViewModelConfig(5L, Lists.newArrayList(8L), Lists.newArrayList()),
|
||||||
|
new ViewModelConfig(6L, Lists.newArrayList(9L, 10L), Lists.newArrayList()),
|
||||||
|
new ViewModelConfig(7L, Lists.newArrayList(11L, 12L), Lists.newArrayList()),
|
||||||
|
new ViewModelConfig(8L, Lists.newArrayList(13L, 14L, 15L), Lists.newArrayList(8L, 9L))
|
||||||
|
);
|
||||||
|
ViewDetail viewDetail = new ViewDetail();
|
||||||
|
viewDetail.setViewModelConfigs(viewModelConfigs);
|
||||||
|
viewReq.setViewDetail(viewDetail);
|
||||||
|
viewReq.setTypeEnum(TypeEnums.VIEW);
|
||||||
|
QueryConfig queryConfig = new QueryConfig();
|
||||||
|
TagTypeDefaultConfig tagTypeDefaultConfig = new TagTypeDefaultConfig();
|
||||||
|
TimeDefaultConfig tagTimeDefaultConfig = new TimeDefaultConfig();
|
||||||
|
tagTimeDefaultConfig.setTimeMode(TimeMode.LAST);
|
||||||
|
tagTimeDefaultConfig.setUnit(7);
|
||||||
|
tagTypeDefaultConfig.setTimeDefaultConfig(tagTimeDefaultConfig);
|
||||||
|
tagTypeDefaultConfig.setDimensionIds(Lists.newArrayList());
|
||||||
|
tagTypeDefaultConfig.setMetricIds(Lists.newArrayList());
|
||||||
|
MetricTypeDefaultConfig metricTypeDefaultConfig = new MetricTypeDefaultConfig();
|
||||||
|
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
|
||||||
|
timeDefaultConfig.setTimeMode(TimeMode.RECENT);
|
||||||
|
timeDefaultConfig.setUnit(7);
|
||||||
|
metricTypeDefaultConfig.setTimeDefaultConfig(timeDefaultConfig);
|
||||||
|
queryConfig.setTagTypeDefaultConfig(tagTypeDefaultConfig);
|
||||||
|
queryConfig.setMetricTypeDefaultConfig(metricTypeDefaultConfig);
|
||||||
|
viewReq.setQueryConfig(queryConfig);
|
||||||
|
viewService.save(viewReq, User.getFakeUser());
|
||||||
|
}
|
||||||
|
|
||||||
public void addModelRela_1() {
|
public void addModelRela_1() {
|
||||||
List<JoinCondition> joinConditions = Lists.newArrayList();
|
List<JoinCondition> joinConditions = Lists.newArrayList();
|
||||||
joinConditions.add(new JoinCondition("g_name", "g_name", FilterOperatorEnum.EQUALS));
|
joinConditions.add(new JoinCondition("g_name", "g_name", FilterOperatorEnum.EQUALS));
|
||||||
@@ -253,4 +303,4 @@ public class BenchMarkDemoDataLoader {
|
|||||||
modelRelaReq.setJoinConditions(joinConditions);
|
modelRelaReq.setJoinConditions(joinConditions);
|
||||||
modelRelaService.save(modelRelaReq, user);
|
modelRelaService.save(modelRelaReq, user);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -220,7 +220,7 @@ public class ChatDemoLoader implements CommandLineRunner {
|
|||||||
LLMParserTool llmParserTool = new LLMParserTool();
|
LLMParserTool llmParserTool = new LLMParserTool();
|
||||||
llmParserTool.setId("1");
|
llmParserTool.setId("1");
|
||||||
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
|
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
|
||||||
llmParserTool.setViewIds(Lists.newArrayList(5L, 6L, 7L, 8L));
|
llmParserTool.setViewIds(Lists.newArrayList(3L));
|
||||||
agentConfig.getTools().add(llmParserTool);
|
agentConfig.getTools().add(llmParserTool);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -242,7 +242,7 @@ public class ChatDemoLoader implements CommandLineRunner {
|
|||||||
LLMParserTool llmParserTool = new LLMParserTool();
|
LLMParserTool llmParserTool = new LLMParserTool();
|
||||||
llmParserTool.setId("1");
|
llmParserTool.setId("1");
|
||||||
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
|
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
|
||||||
llmParserTool.setViewIds(Lists.newArrayList(9L, 10L, 11L, 12L));
|
llmParserTool.setViewIds(Lists.newArrayList(4L));
|
||||||
agentConfig.getTools().add(llmParserTool);
|
agentConfig.getTools().add(llmParserTool);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import com.tencent.supersonic.auth.api.authentication.pojo.User;
|
|||||||
import com.tencent.supersonic.common.pojo.JoinCondition;
|
import com.tencent.supersonic.common.pojo.JoinCondition;
|
||||||
import com.tencent.supersonic.common.pojo.ModelRela;
|
import com.tencent.supersonic.common.pojo.ModelRela;
|
||||||
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.TimeMode;
|
||||||
|
import com.tencent.supersonic.common.pojo.enums.TypeEnums;
|
||||||
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
|
||||||
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
|
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
|
||||||
import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
|
import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
|
||||||
@@ -13,14 +15,21 @@ import com.tencent.supersonic.headless.api.pojo.DimensionTimeTypeParams;
|
|||||||
import com.tencent.supersonic.headless.api.pojo.Identify;
|
import com.tencent.supersonic.headless.api.pojo.Identify;
|
||||||
import com.tencent.supersonic.headless.api.pojo.Measure;
|
import com.tencent.supersonic.headless.api.pojo.Measure;
|
||||||
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
|
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.ViewDetail;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.ViewModelConfig;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.QueryConfig;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.MetricTypeDefaultConfig;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.TimeDefaultConfig;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.DomainReq;
|
import com.tencent.supersonic.headless.api.pojo.request.DomainReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.MetricReq;
|
import com.tencent.supersonic.headless.api.pojo.request.MetricReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
|
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
|
||||||
|
import com.tencent.supersonic.headless.api.pojo.request.ViewReq;
|
||||||
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
|
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
|
||||||
import com.tencent.supersonic.headless.server.service.DomainService;
|
import com.tencent.supersonic.headless.server.service.DomainService;
|
||||||
import com.tencent.supersonic.headless.server.service.MetricService;
|
import com.tencent.supersonic.headless.server.service.MetricService;
|
||||||
import com.tencent.supersonic.headless.server.service.ModelRelaService;
|
import com.tencent.supersonic.headless.server.service.ModelRelaService;
|
||||||
import com.tencent.supersonic.headless.server.service.ModelService;
|
import com.tencent.supersonic.headless.server.service.ModelService;
|
||||||
|
import com.tencent.supersonic.headless.server.service.ViewService;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
@@ -46,6 +55,9 @@ public class DuSQLDemoDataLoader {
|
|||||||
@Autowired
|
@Autowired
|
||||||
private MetricService metricService;
|
private MetricService metricService;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private ViewService viewService;
|
||||||
|
|
||||||
public void doRun() {
|
public void doRun() {
|
||||||
try {
|
try {
|
||||||
addDomain();
|
addDomain();
|
||||||
@@ -53,6 +65,7 @@ public class DuSQLDemoDataLoader {
|
|||||||
addModel_2();
|
addModel_2();
|
||||||
addModel_3();
|
addModel_3();
|
||||||
addModel_4();
|
addModel_4();
|
||||||
|
addView_1();
|
||||||
addModelRela_1();
|
addModelRela_1();
|
||||||
addModelRela_2();
|
addModelRela_2();
|
||||||
addModelRela_3();
|
addModelRela_3();
|
||||||
@@ -241,6 +254,34 @@ public class DuSQLDemoDataLoader {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void addView_1() {
|
||||||
|
ViewReq viewReq = new ViewReq();
|
||||||
|
viewReq.setName("DuSQL 互联网企业");
|
||||||
|
viewReq.setBizName("internet");
|
||||||
|
viewReq.setDomainId(4L);
|
||||||
|
viewReq.setDescription("DuSQL互联网企业数据源相关的指标和维度等");
|
||||||
|
viewReq.setAdmins(Lists.newArrayList("admin"));
|
||||||
|
List<ViewModelConfig> viewModelConfigs = Lists.newArrayList(
|
||||||
|
new ViewModelConfig(9L, Lists.newArrayList(16L, 17L, 18L, 19L, 20L), Lists.newArrayList(10L, 11L)),
|
||||||
|
new ViewModelConfig(10L, Lists.newArrayList(21L, 22L, 23L), Lists.newArrayList(12L)),
|
||||||
|
new ViewModelConfig(11L, Lists.newArrayList(), Lists.newArrayList(13L, 14L, 15L)),
|
||||||
|
new ViewModelConfig(12L, Lists.newArrayList(24L), Lists.newArrayList(16L, 17L, 18L, 19L)));
|
||||||
|
|
||||||
|
ViewDetail viewDetail = new ViewDetail();
|
||||||
|
viewDetail.setViewModelConfigs(viewModelConfigs);
|
||||||
|
viewReq.setViewDetail(viewDetail);
|
||||||
|
viewReq.setTypeEnum(TypeEnums.VIEW);
|
||||||
|
QueryConfig queryConfig = new QueryConfig();
|
||||||
|
MetricTypeDefaultConfig metricTypeDefaultConfig = new MetricTypeDefaultConfig();
|
||||||
|
TimeDefaultConfig timeDefaultConfig = new TimeDefaultConfig();
|
||||||
|
timeDefaultConfig.setTimeMode(TimeMode.RECENT);
|
||||||
|
timeDefaultConfig.setUnit(1);
|
||||||
|
metricTypeDefaultConfig.setTimeDefaultConfig(timeDefaultConfig);
|
||||||
|
queryConfig.setMetricTypeDefaultConfig(metricTypeDefaultConfig);
|
||||||
|
viewReq.setQueryConfig(queryConfig);
|
||||||
|
viewService.save(viewReq, User.getFakeUser());
|
||||||
|
}
|
||||||
|
|
||||||
public void addModelRela_1() {
|
public void addModelRela_1() {
|
||||||
List<JoinCondition> joinConditions = Lists.newArrayList();
|
List<JoinCondition> joinConditions = Lists.newArrayList();
|
||||||
joinConditions.add(new JoinCondition("company_id", "company_id", FilterOperatorEnum.EQUALS));
|
joinConditions.add(new JoinCondition("company_id", "company_id", FilterOperatorEnum.EQUALS));
|
||||||
|
|||||||
@@ -548,4 +548,4 @@ public class ModelDemoDataLoader {
|
|||||||
return relateDimension;
|
return relateDimension;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user