(improvement) revise evaluation and fix null pointer (#715)

This commit is contained in:
mainmain
2024-02-04 20:16:07 +08:00
committed by GitHub
parent 75853a8e9e
commit da5e7b9b75
29 changed files with 431 additions and 122 deletions

View File

@@ -124,19 +124,23 @@ public class ConfigServiceImpl implements ConfigService {
Long modelId = chatConfig.getModelId(); Long modelId = chatConfig.getModelId();
List<Long> blackDimIdList = new ArrayList<>(); List<Long> blackDimIdList = new ArrayList<>();
if (Objects.nonNull(chatConfig.getChatAggConfig()) && Objects.nonNull(chatConfig.getChatAggConfig())) { if (Objects.nonNull(chatConfig.getChatAggConfig())
&& Objects.nonNull(chatConfig.getChatAggConfig().getVisibility())) {
blackDimIdList.addAll(chatConfig.getChatAggConfig().getVisibility().getBlackDimIdList()); blackDimIdList.addAll(chatConfig.getChatAggConfig().getVisibility().getBlackDimIdList());
} }
if (Objects.nonNull(chatConfig.getChatDetailConfig()) && Objects.nonNull(chatConfig.getChatDetailConfig())) { if (Objects.nonNull(chatConfig.getChatDetailConfig())
&& Objects.nonNull(chatConfig.getChatDetailConfig().getVisibility())) {
blackDimIdList.addAll(chatConfig.getChatDetailConfig().getVisibility().getBlackDimIdList()); blackDimIdList.addAll(chatConfig.getChatDetailConfig().getVisibility().getBlackDimIdList());
} }
List<Long> filterDimIdList = blackDimIdList.stream().distinct().collect(Collectors.toList()); List<Long> filterDimIdList = blackDimIdList.stream().distinct().collect(Collectors.toList());
List<Long> blackMetricIdList = new ArrayList<>(); List<Long> blackMetricIdList = new ArrayList<>();
if (Objects.nonNull(chatConfig.getChatAggConfig()) && Objects.nonNull(chatConfig.getChatAggConfig())) { if (Objects.nonNull(chatConfig.getChatAggConfig())
&& Objects.nonNull(chatConfig.getChatAggConfig().getVisibility())) {
blackMetricIdList.addAll(chatConfig.getChatAggConfig().getVisibility().getBlackMetricIdList()); blackMetricIdList.addAll(chatConfig.getChatAggConfig().getVisibility().getBlackMetricIdList());
} }
if (Objects.nonNull(chatConfig.getChatDetailConfig()) && Objects.nonNull(chatConfig.getChatDetailConfig())) { if (Objects.nonNull(chatConfig.getChatDetailConfig())
&& Objects.nonNull(chatConfig.getChatDetailConfig().getVisibility())) {
blackMetricIdList.addAll(chatConfig.getChatDetailConfig().getVisibility().getBlackMetricIdList()); blackMetricIdList.addAll(chatConfig.getChatDetailConfig().getVisibility().getBlackMetricIdList());
} }
List<Long> filterMetricIdList = blackMetricIdList.stream().distinct().collect(Collectors.toList()); List<Long> filterMetricIdList = blackMetricIdList.stream().distinct().collect(Collectors.toList());

View File

@@ -1,8 +1,8 @@
# 评测流程 # 评测流程
1. 正常启动项目(必须包括LLM服务) 1. 正常启动项目(必须包括LLM服务)
2. 执行evalution.sh脚本主要包括构建表数据、获取模型预测结果执行对比逻辑。可以在命令行看到执行准确率错误case会写到同目录的error_case.json文件中。 2. 执行evalution.sh脚本主要包括构建表数据、数据建模、获取模型预测结果执行对比逻辑。可以在命令行看到执行准确率错误case会写到同目录的error_case.json文件中。
# 评测意义 # 评测意义
制定评估工具对于提示词或代码更改的影响至关重要,方便supersonic快速对接其他模型、更改配置,可以帮助我们了解这些变化是否会提高或降低准确率、响应速度。 制定评估工具方便supersonic快速对接其他模型、更改参数配置,对于提示词或代码更改的影响至关重要,可以帮助我们了解这些变化是否会提高或降低准确率、响应速度。

354
evaluation/build_models.py Normal file
View File

@@ -0,0 +1,354 @@
import sqlite3
import os
import requests
import datetime
import yaml
import json
import time
import jwt
def get_authorization():
    """Build the ``Authorization`` header value used by the requests in this script.

    Signs a JWT carrying ``token_user_name="admin"`` and an expiry 1000
    seconds from now, using HS512 and the hard-coded key ``"secret"``.

    Returns:
        str: ``"Bearer "`` followed by the encoded token.
    """
    # NOTE(review): the signing key is hard-coded; presumably it has to match
    # the key configured on the server being evaluated -- confirm.
    exp = time.time() + 1000
    token = jwt.encode(
        {"token_user_name": "admin", "exp": exp}, "secret", algorithm="HS512"
    )
    # PyJWT 1.x returns bytes from encode() while 2.x returns str. Normalise
    # so the concatenation below does not raise TypeError on the older API.
    if isinstance(token, bytes):
        token = token.decode("utf-8")
    return "Bearer " + token
def get_url_pre():
    """Return the ``url`` entry of ``config/config.yaml`` located next to this script.

    The file is re-read on every call; nothing is cached.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    with open(script_dir + "/config/config.yaml", 'r') as handle:
        settings = yaml.safe_load(handle)
    return settings["url"]
def get_list(url):
    """GET *url* with the admin bearer token and unwrap the response envelope.

    Args:
        url: Full endpoint URL to query.

    Returns:
        The ``"data"`` member of the JSON reply when its ``"code"`` member
        equals 200, otherwise ``None``.
    """
    headers = {"Authorization": get_authorization()}
    payload = requests.get(url=url, headers=headers).json()
    if payload["code"] != 200:
        return None
    return payload["data"]
def build_domain():
    """Ensure the ``internet`` domain exists and report its id.

    Looks for a domain whose ``bizName`` is ``"internet"`` in the list
    returned by ``getDomainList``. When none is found (or the list request
    did not succeed) the domain is created through ``createDomain`` and the
    list is fetched again.

    Returns:
        dict: ``{"build": bool, "domain_id": id}``. ``build`` is ``True``
        when this call had to create the domain.

    Raises:
        RuntimeError: if the domain still cannot be found after the create
            request.
    """
    json_data='{"name":"DuSQL_互联网企业","bizName":"internet","sensitiveLevel":0,"parentId":0,"isOpen":0,"viewers":["admin","tom","jack"],"viewOrgs":["1"],"admins":["admin"],"adminOrgs":[],"admin":"admin","viewer":"admin,tom,jack","viewOrg":"1","adminOrg":""}'
    json_dict = json.loads(json_data)
    url_pre = get_url_pre()
    list_url = url_pre + "/api/semantic/domain/getDomainList"

    def find_domain_id(domains):
        # get_list() yields None on a non-200 reply; treat that as "no domains".
        for domain in domains or []:
            if domain["bizName"] == "internet":
                return domain["id"]
        return None

    # Resolve the id by bizName instead of taking the last list entry: the
    # last entry is only the right one if our domain happens to be listed last.
    domain_id = find_domain_id(get_list(list_url))
    build = domain_id is None
    if build:
        header = {"Authorization": get_authorization()}
        requests.post(
            url=url_pre + "/api/semantic/domain/createDomain",
            headers=header,
            json=json_dict,
        )
        domain_id = find_domain_id(get_list(list_url))
        if domain_id is None:
            raise RuntimeError("domain 'internet' was not found after createDomain")
    return {"build": build, "domain_id": domain_id}
def build_model_1(domain_id):
    """Ensure the ``company`` model exists in *domain_id* and return its id.

    The model definition is a fixed JSON template (SQL query over the
    ``company`` table, ``databaseId`` 1). It is posted to ``createModel``
    only when no model with ``bizName == "company"`` is listed for the domain.

    Args:
        domain_id: Id of the domain the model belongs to.

    Returns:
        The ``id`` of the ``company`` model.

    Raises:
        RuntimeError: if the model cannot be found after the create request.
    """
    json_data='{"name":"公司","bizName":"company","description":"公司","sensitiveLevel":0,"databaseId":1,"domainId":4,"modelDetail":{"queryType":"sql_query","sqlQuery":"SELECT imp_date,company_id,company_name,headquarter_address,company_established_time,founder,ceo,annual_turnover,employee_count FROM company","identifiers":[{"name":"公司id","type":"primary","bizName":"company_id","isCreateDimension":0,"fieldName":"company_id"}],"dimensions":[{"name":"","type":"time","dateFormat":"yyyy-MM-dd","typeParams":{"isPrimary":"false","timeGranularity":"none"},"isCreateDimension":0,"bizName":"imp_date","isTag":0,"fieldName":"imp_date"},{"name":"公司名称","type":"categorical","dateFormat":"yyyy-MM-dd","isCreateDimension":1,"bizName":"company_name","isTag":0,"fieldName":"company_name"},{"name":"总部地点","type":"categorical","dateFormat":"yyyy-MM-dd","isCreateDimension":1,"bizName":"headquarter_address","isTag":0,"fieldName":"headquarter_address"},{"name":"公司成立时间","type":"categorical","dateFormat":"yyyy-MM-dd","isCreateDimension":1,"bizName":"company_established_time","isTag":0,"fieldName":"company_established_time"},{"name":"创始人","type":"categorical","dateFormat":"yyyy-MM-dd","isCreateDimension":1,"bizName":"founder","isTag":0,"fieldName":"founder"},{"name":"首席执行官","type":"categorical","dateFormat":"yyyy-MM-dd","isCreateDimension":1,"bizName":"ceo","isTag":0,"fieldName":"ceo"}],"measures":[{"name":"年营业额","agg":"SUM","bizName":"annual_turnover","isCreateMetric":1},{"name":"员工数","agg":"SUM","bizName":"employee_count","isCreateMetric":1}],"fields":[{"fieldName":"company_id"},{"fieldName":"imp_date"},{"fieldName":"company_established_time"},{"fieldName":"founder"},{"fieldName":"headquarter_address"},{"fieldName":"ceo"},{"fieldName":"company_name"}]},"viewers":["admin","tom","jack"],"viewOrgs":["1"],"admins":["admin"],"adminOrgs":[],"admin":"admin","viewer":"admin,tom,jack","viewOrg":"1","timeDimension":[{"name":"","type":"time","dateFormat":"yyyy-MM-dd","typeParams":{"isPrimary":"false","timeGranularity":"none"},"isCreateDimension":0,"bizName":"imp_date","isTag":0,"fieldName":"imp_date"}],"adminOrg":""}'
    json_dict = json.loads(json_data)
    # The template carries a placeholder domainId; bind it to the real domain.
    json_dict["domainId"] = domain_id
    url_pre = get_url_pre()
    list_url = url_pre + "/api/semantic/model/getModelList/" + str(domain_id)

    def find_model_id(models):
        # get_list() yields None on a non-200 reply; treat that as "no models".
        for model in models or []:
            if model["bizName"] == "company":
                return model["id"]
        return None

    # Resolve the id by bizName rather than assuming our model is the last
    # entry of the list, which breaks as soon as other models follow it.
    model_id = find_model_id(get_list(list_url))
    if model_id is None:
        header = {"Authorization": get_authorization()}
        requests.post(
            url=url_pre + "/api/semantic/model/createModel",
            headers=header,
            json=json_dict,
        )
        model_id = find_model_id(get_list(list_url))
        if model_id is None:
            raise RuntimeError("model 'company' was not found after createModel")
    return model_id
def build_model_2(domain_id):
    """Ensure the ``brand`` model exists in *domain_id* and return its id.

    The model definition is a fixed JSON template (SQL query over the
    ``brand`` table, ``databaseId`` 1). It is posted to ``createModel`` only
    when no model with ``bizName == "brand"`` is listed for the domain.

    Args:
        domain_id: Id of the domain the model belongs to.

    Returns:
        The ``id`` of the ``brand`` model.

    Raises:
        RuntimeError: if the model cannot be found after the create request.
    """
    json_data='{"name":"品牌","bizName":"brand","description":"品牌","sensitiveLevel":0,"databaseId":1,"domainId":4,"modelDetail":{"queryType":"sql_query","sqlQuery":"SELECT imp_date,brand_id,brand_name,brand_established_time,company_id,legal_representative,registered_capital FROM brand","identifiers":[{"name":"品牌id","type":"primary","bizName":"brand_id","isCreateDimension":0,"fieldName":"brand_id"},{"name":"公司id","type":"foreign","bizName":"company_id","isCreateDimension":0,"fieldName":"company_id"}],"dimensions":[{"name":"","type":"time","dateFormat":"yyyy-MM-dd","typeParams":{"isPrimary":"false","timeGranularity":"none"},"isCreateDimension":0,"bizName":"imp_date","isTag":0,"fieldName":"imp_date"},{"name":"品牌名称","type":"categorical","dateFormat":"yyyy-MM-dd","isCreateDimension":1,"bizName":"brand_name","isTag":0,"fieldName":"brand_name"},{"name":"品牌成立时间","type":"categorical","dateFormat":"yyyy-MM-dd","isCreateDimension":1,"bizName":"brand_established_time","isTag":0,"fieldName":"brand_established_time"},{"name":"法定代表人","type":"categorical","dateFormat":"yyyy-MM-dd","isCreateDimension":1,"bizName":"legal_representative","isTag":0,"fieldName":"legal_representative"}],"measures":[{"name":"注册资本","agg":"SUM","bizName":"registered_capital","isCreateMetric":1}],"fields":[{"fieldName":"company_id"},{"fieldName":"brand_id"},{"fieldName":"brand_name"},{"fieldName":"imp_date"},{"fieldName":"brand_established_time"},{"fieldName":"legal_representative"}]},"viewers":["admin","tom","jack"],"viewOrgs":["1"],"admins":["admin"],"adminOrgs":[],"admin":"admin","viewer":"admin,tom,jack","viewOrg":"1","timeDimension":[{"name":"","type":"time","dateFormat":"yyyy-MM-dd","typeParams":{"isPrimary":"false","timeGranularity":"none"},"isCreateDimension":0,"bizName":"imp_date","isTag":0,"fieldName":"imp_date"}],"adminOrg":""}'
    json_dict = json.loads(json_data)
    # The template carries a placeholder domainId; bind it to the real domain.
    json_dict["domainId"] = domain_id
    url_pre = get_url_pre()
    list_url = url_pre + "/api/semantic/model/getModelList/" + str(domain_id)

    def find_model_id(models):
        # get_list() yields None on a non-200 reply; treat that as "no models".
        for model in models or []:
            if model["bizName"] == "brand":
                return model["id"]
        return None

    # Resolve the id by bizName rather than assuming our model is the last
    # entry of the list, which breaks as soon as other models follow it.
    model_id = find_model_id(get_list(list_url))
    if model_id is None:
        header = {"Authorization": get_authorization()}
        requests.post(
            url=url_pre + "/api/semantic/model/createModel",
            headers=header,
            json=json_dict,
        )
        model_id = find_model_id(get_list(list_url))
        if model_id is None:
            raise RuntimeError("model 'brand' was not found after createModel")
    return model_id
def build_model_3(domain_id):
    """Ensure the ``company_revenue`` model exists in *domain_id* and return its id.

    The model definition is a fixed JSON template (SQL query over the
    ``company_revenue`` table, ``databaseId`` 1). It is posted to
    ``createModel`` only when no model with
    ``bizName == "company_revenue"`` is listed for the domain.

    Args:
        domain_id: Id of the domain the model belongs to.

    Returns:
        The ``id`` of the ``company_revenue`` model.

    Raises:
        RuntimeError: if the model cannot be found after the create request.
    """
    json_data='{"name":"公司各品牌收入排名","bizName":"company_revenue","description":"公司各品牌收入排名","sensitiveLevel":0,"databaseId":1,"domainId":4,"modelDetail":{"queryType":"sql_query","sqlQuery":"SELECT imp_date,company_id,brand_id,revenue_proportion,profit_proportion,expenditure_proportion FROM company_revenue","identifiers":[{"name":"公司id","type":"foreign","bizName":"company_id","isCreateDimension":0,"fieldName":"company_id"},{"name":"品牌id","type":"foreign","bizName":"brand_id","isCreateDimension":0,"fieldName":"brand_id"}],"dimensions":[{"name":"","type":"time","dateFormat":"yyyy-MM-dd","typeParams":{"isPrimary":"false","timeGranularity":"none"},"isCreateDimension":0,"bizName":"imp_date","isTag":0,"fieldName":"imp_date"}],"measures":[{"name":"营收占比","agg":"SUM","bizName":"revenue_proportion","isCreateMetric":1},{"name":"利润占比","agg":"SUM","bizName":"profit_proportion","isCreateMetric":1},{"name":"支出占比","agg":"SUM","bizName":"expenditure_proportion","isCreateMetric":1}],"fields":[{"fieldName":"company_id"},{"fieldName":"brand_id"},{"fieldName":"imp_date"}]},"viewers":["admin","tom","jack"],"viewOrgs":["1"],"admins":["admin"],"adminOrgs":[],"admin":"admin","viewer":"admin,tom,jack","viewOrg":"1","timeDimension":[{"name":"","type":"time","dateFormat":"yyyy-MM-dd","typeParams":{"isPrimary":"false","timeGranularity":"none"},"isCreateDimension":0,"bizName":"imp_date","isTag":0,"fieldName":"imp_date"}],"adminOrg":""}'
    json_dict = json.loads(json_data)
    # The template carries a placeholder domainId; bind it to the real domain.
    json_dict["domainId"] = domain_id
    url_pre = get_url_pre()
    list_url = url_pre + "/api/semantic/model/getModelList/" + str(domain_id)

    def find_model_id(models):
        # get_list() yields None on a non-200 reply; treat that as "no models".
        for model in models or []:
            if model["bizName"] == "company_revenue":
                return model["id"]
        return None

    # Resolve the id by bizName rather than assuming our model is the last
    # entry of the list, which breaks as soon as other models follow it.
    model_id = find_model_id(get_list(list_url))
    if model_id is None:
        header = {"Authorization": get_authorization()}
        requests.post(
            url=url_pre + "/api/semantic/model/createModel",
            headers=header,
            json=json_dict,
        )
        model_id = find_model_id(get_list(list_url))
        if model_id is None:
            raise RuntimeError("model 'company_revenue' was not found after createModel")
    return model_id
def build_model_4(domain_id):
    """Ensure the ``company_brand_revenue`` model exists in *domain_id* and return its id.

    The model definition is a fixed JSON template (SQL query over the
    ``company_brand_revenue`` table, ``databaseId`` 1). It is posted to
    ``createModel`` only when no model with
    ``bizName == "company_brand_revenue"`` is listed for the domain.

    Args:
        domain_id: Id of the domain the model belongs to.

    Returns:
        The ``id`` of the ``company_brand_revenue`` model.

    Raises:
        RuntimeError: if the model cannot be found after the create request.
    """
    json_data='{"name":"公司品牌历年收入","bizName":"company_brand_revenue","description":"公司品牌历年收入","sensitiveLevel":0,"databaseId":1,"domainId":4,"modelDetail":{"queryType":"sql_query","sqlQuery":"SELECT imp_date,year_time,brand_id,revenue,profit,revenue_growth_year_on_year,profit_growth_year_on_year FROM company_brand_revenue","identifiers":[{"name":"品牌id","type":"foreign","bizName":"brand_id","isCreateDimension":0,"fieldName":"brand_id"}],"dimensions":[{"name":"","type":"time","dateFormat":"yyyy-MM-dd","typeParams":{"isPrimary":"false","timeGranularity":"none"},"isCreateDimension":0,"bizName":"imp_date","isTag":0,"fieldName":"imp_date"},{"name":"年份","type":"categorical","dateFormat":"yyyy-MM-dd","isCreateDimension":1,"bizName":"year_time","isTag":0,"fieldName":"year_time"}],"measures":[{"name":"营收","agg":"SUM","bizName":"revenue","isCreateMetric":1},{"name":"利润","agg":"SUM","bizName":"profit","isCreateMetric":1},{"name":"营收同比增长","agg":"SUM","bizName":"revenue_growth_year_on_year","isCreateMetric":1},{"name":"利润同比增长","agg":"SUM","bizName":"profit_growth_year_on_year","isCreateMetric":1}],"fields":[{"fieldName":"brand_id"},{"fieldName":"imp_date"},{"fieldName":"year_time"}]},"viewers":["admin","tom","jack"],"viewOrgs":["1"],"admins":["admin"],"adminOrgs":[],"admin":"admin","viewer":"admin,tom,jack","viewOrg":"1","timeDimension":[{"name":"","type":"time","dateFormat":"yyyy-MM-dd","typeParams":{"isPrimary":"false","timeGranularity":"none"},"isCreateDimension":0,"bizName":"imp_date","isTag":0,"fieldName":"imp_date"}],"adminOrg":""}'
    json_dict = json.loads(json_data)
    # The template carries a placeholder domainId; bind it to the real domain.
    json_dict["domainId"] = domain_id
    url_pre = get_url_pre()
    list_url = url_pre + "/api/semantic/model/getModelList/" + str(domain_id)

    def find_model_id(models):
        # get_list() yields None on a non-200 reply; treat that as "no models".
        for model in models or []:
            if model["bizName"] == "company_brand_revenue":
                return model["id"]
        return None

    # Resolve the id by bizName rather than assuming our model is the last
    # entry of the list, which breaks as soon as other models follow it.
    model_id = find_model_id(get_list(list_url))
    if model_id is None:
        header = {"Authorization": get_authorization()}
        requests.post(
            url=url_pre + "/api/semantic/model/createModel",
            headers=header,
            json=json_dict,
        )
        model_id = find_model_id(get_list(list_url))
        if model_id is None:
            raise RuntimeError("model 'company_brand_revenue' was not found after createModel")
    return model_id
def build_model_rela1(domain_id,from_model_id,to_model_id):
    """POST an inner-join relation on ``company_id`` between two models.

    Args:
        domain_id: Domain the relation is registered under.
        from_model_id: Id sent as ``fromModelId``.
        to_model_id: Id sent as ``toModelId``.

    The response of the request is not inspected.
    """
    # The ids inside the template are placeholders and are overwritten below.
    payload = json.loads('{"domainId":4,"fromModelId":9,"toModelId":10,"joinType":"inner join","joinConditions":[{"leftField":"company_id","rightField":"company_id","operator":"="}]}')
    payload.update(
        {"domainId": domain_id, "fromModelId": from_model_id, "toModelId": to_model_id}
    )
    endpoint = get_url_pre() + "/api/semantic/modelRela"
    requests.post(
        url=endpoint, headers={"Authorization": get_authorization()}, json=payload
    )
def build_model_rela2(domain_id,from_model_id,to_model_id):
    """POST an inner-join relation on ``company_id`` between two models.

    Args:
        domain_id: Domain the relation is registered under.
        from_model_id: Id sent as ``fromModelId``.
        to_model_id: Id sent as ``toModelId``.

    The response of the request is not inspected.
    """
    # The ids inside the template are placeholders and are overwritten below.
    payload = json.loads('{"domainId":4,"fromModelId":9,"toModelId":11,"joinType":"inner join","joinConditions":[{"leftField":"company_id","rightField":"company_id","operator":"="}]}')
    payload.update(
        {"domainId": domain_id, "fromModelId": from_model_id, "toModelId": to_model_id}
    )
    endpoint = get_url_pre() + "/api/semantic/modelRela"
    requests.post(
        url=endpoint, headers={"Authorization": get_authorization()}, json=payload
    )
def build_model_rela3(domain_id,from_model_id,to_model_id):
    """POST an inner-join relation on ``brand_id`` between two models.

    Args:
        domain_id: Domain the relation is registered under.
        from_model_id: Id sent as ``fromModelId``.
        to_model_id: Id sent as ``toModelId``.

    The response of the request is not inspected.
    """
    # The ids inside the template are placeholders and are overwritten below.
    payload = json.loads('{"domainId":4,"fromModelId":10,"toModelId":11,"joinType":"inner join","joinConditions":[{"leftField":"brand_id","rightField":"brand_id","operator":"="}]}')
    payload.update(
        {"domainId": domain_id, "fromModelId": from_model_id, "toModelId": to_model_id}
    )
    endpoint = get_url_pre() + "/api/semantic/modelRela"
    requests.post(
        url=endpoint, headers={"Authorization": get_authorization()}, json=payload
    )
def build_model_rela4(domain_id,from_model_id,to_model_id):
    """POST an inner-join relation on ``brand_id`` between two models.

    Args:
        domain_id: Domain the relation is registered under.
        from_model_id: Id sent as ``fromModelId``.
        to_model_id: Id sent as ``toModelId``.

    The response of the request is not inspected.
    """
    # The ids inside the template are placeholders and are overwritten below.
    payload = json.loads('{"domainId":4,"fromModelId":10,"toModelId":12,"joinType":"inner join","joinConditions":[{"leftField":"brand_id","rightField":"brand_id","operator":"="}]}')
    payload.update(
        {"domainId": domain_id, "fromModelId": from_model_id, "toModelId": to_model_id}
    )
    endpoint = get_url_pre() + "/api/semantic/modelRela"
    requests.post(
        url=endpoint, headers={"Authorization": get_authorization()}, json=payload
    )
def get_id_list(data_list):
    """Collect the ``"id"`` value of every entry in *data_list*.

    Args:
        data_list: Iterable of mappings that each have an ``"id"`` key, or
            ``None``.

    Returns:
        list: The ids in input order; an empty list when *data_list* is
        ``None``.
    """
    if data_list is None:
        return []
    return [entry["id"] for entry in data_list]
def build_view(domain_id,model_id1,model_id2,model_id3,model_id4):
    """Create the ``internet`` view over four models and describe the result.

    Fetches the dimension ids and then the metric ids of each model, posts a
    view that lists them explicitly (``includesAll`` is ``False``), and then
    reads the view list of the domain back.

    Args:
        domain_id: Domain the view is created in.
        model_id1, model_id2, model_id3, model_id4: Ids of the models to include.

    Returns:
        dict: ``{"id": <id of the first view listed for the domain>,
        "dim": {model_id: [dimension ids]}}``.
    """
    model_ids = (model_id1, model_id2, model_id3, model_id4)
    # All dimension lists are fetched first, then all metric lists.
    dimension_lists = [
        get_id_list(get_list(
            get_url_pre() + "/api/semantic/dimension/getDimensionList/" + str(mid)
        ))
        for mid in model_ids
    ]
    metric_lists = [
        get_id_list(get_list(
            get_url_pre() + "/api/semantic/metric/getMetricList/" + str(mid)
        ))
        for mid in model_ids
    ]
    view_model_configs = [
        {"id": mid, "includesAll": False, "metrics": metrics, "dimensions": dims}
        for mid, metrics, dims in zip(model_ids, metric_lists, dimension_lists)
    ]
    json_dict = {
        "name": "DuSQL 互联网企业",
        "bizName": "internet",
        "description": "DuSQL互联网企业数据源相关的指标和维度等",
        "typeEnum": "VIEW",
        "sensitiveLevel": 0,
        "domainId": domain_id,
        "viewDetail": {"viewModelConfigs": view_model_configs},
        "queryConfig": {
            "tagTypeDefaultConfig": {"dimensionIds": [], "metricIds": []},
            "metricTypeDefaultConfig": {
                "timeDefaultConfig": {"unit": 1, "period": "DAY", "timeMode": "RECENT"}
            },
        },
        "admins": ["admin"],
        "admin": "admin",
    }
    create_url = get_url_pre() + "/api/semantic/view"
    requests.post(
        url=create_url, headers={"Authorization": get_authorization()}, json=json_dict
    )
    url = get_url_pre() + "/api/semantic/view/getViewList?domainId=" + str(domain_id)
    print(url)
    views = get_list(url)
    # NOTE(review): takes the first view listed for the domain -- presumably
    # the one just created; confirm the listing order.
    return {"id": views[0]["id"], "dim": dict(zip(model_ids, dimension_lists))}
def build_agent(view_id):
    """POST the definition of agent 10, whose single NL2SQL_LLM tool targets *view_id*.

    Args:
        view_id: Id of the view the agent's tool queries.

    The response of the request is not inspected.
    """
    # agentConfig is sent as a JSON-encoded string, not as a nested object.
    agent_config = json.dumps({
        "tools": [{
            "id": 1,
            "type": "NL2SQL_LLM",
            "viewIds": [view_id]
        }]
    })
    agent = {
        "id": 10,
        "enableSearch": 1,
        "name": "DuSQL 互联网企业",
        "description": "DuSQL",
        "status": 1,
        "examples": [],
        "agentConfig": agent_config,
    }
    endpoint = get_url_pre() + "/api/chat/agent"
    requests.post(
        url=endpoint, headers={"Authorization": get_authorization()}, json=agent
    )
def build_chat(agentId):
    """Create a chat named ``DuSQL问答`` for *agentId* and return a chat id.

    Args:
        agentId: Id of the agent the chat is attached to.

    Returns:
        The ``chatId`` of the first entry returned by ``getAll`` for the agent.
    """
    agent_ref = str(agentId)
    save_url = get_url_pre() + "/api/chat/manage/save?chatName=DuSQL问答&agentId=" + agent_ref
    requests.post(url=save_url, headers={"Authorization": get_authorization()})
    # NOTE(review): assumes the first listed chat is the one to use -- confirm
    # the ordering of getAll.
    chats = get_list(get_url_pre() + "/api/chat/manage/getAll?agentId=" + agent_ref)
    return chats[0]["chatId"]
def build_dim_value_dict(modelIds,info):
    """POST a ``REALTIME_ADD`` dictionary task for the given models.

    Args:
        modelIds: Value sent as ``modelIds``.
        info: Mapping whose ``"dim"`` entry is sent as ``modelAndDimPair``.

    The request body is printed before it is sent; the response is not
    inspected.
    """
    endpoint = get_url_pre() + "/api/chat/dict/task"
    headers = {"Authorization": get_authorization()}
    task = {
        "updateMode": "REALTIME_ADD",
        "modelIds": modelIds,
        "modelAndDimPair": info["dim"],
    }
    print(task)
    requests.post(url=endpoint, headers=headers, json=task)
def build():
    """Provision the evaluation fixtures and return the ids needed to query them.

    When ``build_domain`` reports that it created the domain, the four
    models, the view, the four model relations and the agent are created as
    well. A chat is created for the agent in every case.

    Returns:
        dict: ``{"agent_id": 10, "chat_id": <id returned by build_chat>}``.
    """
    dict_info = build_domain()
    domain_id = dict_info["domain_id"]
    if dict_info["build"]:
        model_id1 = build_model_1(domain_id)
        model_id2 = build_model_2(domain_id)
        model_id3 = build_model_3(domain_id)
        model_id4 = build_model_4(domain_id)
        view_info = build_view(domain_id, model_id1, model_id2, model_id3, model_id4)
        build_model_rela1(domain_id, model_id1, model_id2)
        build_model_rela2(domain_id, model_id1, model_id3)
        build_model_rela3(domain_id, model_id2, model_id3)
        build_model_rela4(domain_id, model_id2, model_id4)
        build_agent(view_info["id"])
    # Same fixed id that build_agent() puts in its payload.
    agent_id = 10
    chat_id = build_chat(agent_id)
    # Built as a literal so the builtin ``dict`` is no longer shadowed.
    return {"agent_id": agent_id, "chat_id": chat_id}
# Script entry point: provision everything and show the resulting ids.
if __name__ == '__main__':
    print(build())

View File

@@ -1,10 +1,11 @@
import time
import requests import requests
import logging import logging
import json import json
import jwt
import time
import os import os
import yaml import yaml
from build_models import build,get_authorization
def read_query(input_path): def read_query(input_path):
result=[] result=[]
@@ -24,13 +25,16 @@ def get_pred_sql(query,url,agentId,chatId,authorization,default_sql):
header["Authorization"] =authorization header["Authorization"] =authorization
try: try:
result = requests.post(url=url, headers=header, json=data) result = requests.post(url=url, headers=header, json=data)
print(result.json())
print(result.json()["traceId"])
if result.status_code == 200: if result.status_code == 200:
data = result.json()["data"] data = result.json()["data"]
selectedParses = data["selectedParses"] selectedParses = data["selectedParses"]
if selectedParses is not None and len(selectedParses) > 0: if selectedParses is not None and len(selectedParses) > 0:
querySQL = selectedParses[0]["sqlInfo"]["querySQL"] querySQL = selectedParses[0]["sqlInfo"]["querySQL"]
querySQL=querySQL.replace("`dusql`.", "").replace("dusql", "").replace("\n", "") if querySQL is not None:
return querySQL+'\n' querySQL=querySQL.replace("`dusql`.", "").replace("dusql", "").replace("\n", "")
return querySQL+'\n'
return default_sql+'\n' return default_sql+'\n'
except Exception as e: except Exception as e:
print(url) print(url)
@@ -38,24 +42,22 @@ def get_pred_sql(query,url,agentId,chatId,authorization,default_sql):
print(e) print(e)
logging.info(e) logging.info(e)
return default_sql+'\n' return default_sql+'\n'
def get_authorization():
exp = time.time() + 1000
token= jwt.encode({"token_user_name": "admin","exp": exp}, "secret", algorithm="HS512")
return "Bearer "+token
def get_pred_result(): def get_pred_result():
current_directory = os.path.dirname(os.path.abspath(__file__)) current_directory = os.path.dirname(os.path.abspath(__file__))
config_file=current_directory+"/config/config.yaml" config_file=current_directory+"/config/config.yaml"
with open(config_file, 'r') as file: with open(config_file, 'r') as file:
config = yaml.safe_load(file) config = yaml.safe_load(file)
input_path=current_directory+"/data/"+"internet.txt" input_path=current_directory+"/data/internet.txt"
pred_sql_path = current_directory+"/data/"+"pred_example_dusql.txt" pred_sql_path = current_directory+"/data/pred_example_dusql.txt"
pred_sql_exist=os.path.exists(pred_sql_path) pred_sql_exist=os.path.exists(pred_sql_path)
if pred_sql_exist: if pred_sql_exist:
os.remove(pred_sql_path) os.remove(pred_sql_path)
print("pred_sql_path removed!") print("pred_sql_path removed!")
agent_id=config["agent_id"] dict_info=build()
chat_id=config["chat_id"] print(dict_info)
agent_id=dict_info["agent_id"]
chat_id=dict_info["chat_id"]
url=config["url"] url=config["url"]
authorization=get_authorization() authorization=get_authorization()
print(input_path) print(input_path)
@@ -66,6 +68,7 @@ def get_pred_result():
for i in range(0,len(questions)): for i in range(0,len(questions)):
pred_sql=get_pred_sql(questions[i],url,agent_id,chat_id,authorization,default_sql) pred_sql=get_pred_sql(questions[i],url,agent_id,chat_id,authorization,default_sql)
pred_sql_list.append(pred_sql) pred_sql_list.append(pred_sql)
time.sleep(30)
write_sql(pred_sql_path, pred_sql_list) write_sql(pred_sql_path, pred_sql_list)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -1,3 +1 @@
chat_id: 3
agent_id: 4
url: http://localhost:9080 url: http://localhost:9080

View File

@@ -1,17 +0,0 @@
[
{
"query": "在各公司所有品牌收入排名中,给出每一个品牌,其所在公司以及收入占该公司的总收入比例,同时给出该公司的年营业额",
"gold_sql": "SELECT T3.company_name, T3.annual_turnover, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id",
"pred_sql": "select * from tablea"
},
{
"query": "在各公司所有品牌收入排名中,给出每一个品牌,其所在公司以及收入占该公司的总收入比例",
"gold_sql": "SELECT T3.company_name, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id",
"pred_sql": "select * from tablea"
},
{
"query": "在各公司所有品牌收入排名中,给出每一个品牌和其法人,其所在公司以及收入占该公司的总收入比例",
"gold_sql": "SELECT T3.company_name, T2.brand_name, T2.legal_representative, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id",
"pred_sql": "select * from tablea"
}
]

View File

@@ -51,4 +51,4 @@ public class DimensionDO {
private String dataType; private String dataType;
private int isTag; private int isTag;
} }

View File

@@ -81,4 +81,4 @@ public class DomainDO {
*/ */
private String viewOrg; private String viewOrg;
} }

View File

@@ -80,17 +80,17 @@ public class MetricDO {
private String dataFormat; private String dataFormat;
/** /**
* *
*/ */
private String alias; private String alias;
/** /**
* *
*/ */
private String tags; private String tags;
/** /**
* *
*/ */
private String relateDimensions; private String relateDimensions;
@@ -103,4 +103,4 @@ public class MetricDO {
private String defineType; private String defineType;
} }

View File

@@ -56,4 +56,4 @@ public class ModelDO {
private String sourceType; private String sourceType;
} }

View File

@@ -48,10 +48,11 @@ public class StatRepositoryImpl implements StatRepository {
statInfos.stream().forEach(stat -> { statInfos.stream().forEach(stat -> {
String dimensions = stat.getDimensions(); String dimensions = stat.getDimensions();
String metrics = stat.getMetrics(); String metrics = stat.getMetrics();
updateStatMapInfo(map, dimensions, TypeEnums.DIMENSION.name().toLowerCase(), stat.getModelId()); if (Objects.nonNull(stat.getViewId())) {
updateStatMapInfo(map, metrics, TypeEnums.METRIC.name().toLowerCase(), stat.getModelId()); updateStatMapInfo(map, dimensions, TypeEnums.DIMENSION.name().toLowerCase(), stat.getViewId());
updateStatMapInfo(map, metrics, TypeEnums.METRIC.name().toLowerCase(), stat.getViewId());
}
}); });
map.forEach((k, v) -> { map.forEach((k, v) -> {
Long classId = Long.parseLong(k.split(AT_SYMBOL + AT_SYMBOL)[0]); Long classId = Long.parseLong(k.split(AT_SYMBOL + AT_SYMBOL)[0]);
String type = k.split(AT_SYMBOL + AT_SYMBOL)[1]; String type = k.split(AT_SYMBOL + AT_SYMBOL)[1];
@@ -68,13 +69,13 @@ public class StatRepositoryImpl implements StatRepository {
return statMapper.getStatInfo(itemUseCommend); return statMapper.getStatInfo(itemUseCommend);
} }
private void updateStatMapInfo(Map<String, Long> map, String dimensions, String type, Long modelId) { private void updateStatMapInfo(Map<String, Long> map, String dimensions, String type, Long viewId) {
if (Strings.isNotEmpty(dimensions)) { if (Strings.isNotEmpty(dimensions)) {
try { try {
List<String> dimensionList = mapper.readValue(dimensions, new TypeReference<List<String>>() { List<String> dimensionList = mapper.readValue(dimensions, new TypeReference<List<String>>() {
}); });
dimensionList.stream().forEach(dimension -> { dimensionList.stream().forEach(dimension -> {
String key = modelId + AT_SYMBOL + AT_SYMBOL + type + AT_SYMBOL + AT_SYMBOL + dimension; String key = viewId + AT_SYMBOL + AT_SYMBOL + type + AT_SYMBOL + AT_SYMBOL + dimension;
if (map.containsKey(key)) { if (map.containsKey(key)) {
map.put(key, map.get(key) + 1); map.put(key, map.get(key) + 1);
} else { } else {
@@ -97,4 +98,4 @@ public class StatRepositoryImpl implements StatRepository {
} }
} }
} }
} }

View File

@@ -103,6 +103,7 @@ public class BenchMarkDemoDataLoader {
dimension1.setTypeParams(new DimensionTimeTypeParams()); dimension1.setTypeParams(new DimensionTimeTypeParams());
dimensions.add(dimension1); dimensions.add(dimension1);
dimensions.add(new Dim("活跃区域", "most_popular_in", DimensionType.categorical.name(), 1)); dimensions.add(new Dim("活跃区域", "most_popular_in", DimensionType.categorical.name(), 1));
dimensions.add(new Dim("音乐类型名称", "g_name", DimensionType.categorical.name(), 1));
modelDetail.setDimensions(dimensions); modelDetail.setDimensions(dimensions);
List<Identify> identifiers = new ArrayList<>(); List<Identify> identifiers = new ArrayList<>();
@@ -129,6 +130,7 @@ public class BenchMarkDemoDataLoader {
modelReq.setDatabaseId(1L); modelReq.setDatabaseId(1L);
ModelDetail modelDetail = new ModelDetail(); ModelDetail modelDetail = new ModelDetail();
List<Dim> dimensions = new ArrayList<>(); List<Dim> dimensions = new ArrayList<>();
dimensions.add(new Dim("艺术家名称", "artist_name", DimensionType.categorical.name(), 1));
dimensions.add(new Dim("国籍", "country", DimensionType.categorical.name(), 1)); dimensions.add(new Dim("国籍", "country", DimensionType.categorical.name(), 1));
dimensions.add(new Dim("性别", "gender", DimensionType.categorical.name(), 1)); dimensions.add(new Dim("性别", "gender", DimensionType.categorical.name(), 1));
modelDetail.setDimensions(dimensions); modelDetail.setDimensions(dimensions);
@@ -157,6 +159,7 @@ public class BenchMarkDemoDataLoader {
List<Dim> dimensions = new ArrayList<>(); List<Dim> dimensions = new ArrayList<>();
dimensions.add(new Dim("持续时间", "duration", DimensionType.categorical.name(), 1)); dimensions.add(new Dim("持续时间", "duration", DimensionType.categorical.name(), 1));
dimensions.add(new Dim("文件格式", "formats", DimensionType.categorical.name(), 1)); dimensions.add(new Dim("文件格式", "formats", DimensionType.categorical.name(), 1));
dimensions.add(new Dim("艺术家名称", "artist_name", DimensionType.categorical.name(), 1));
modelDetail.setDimensions(dimensions); modelDetail.setDimensions(dimensions);
List<Identify> identifiers = new ArrayList<>(); List<Identify> identifiers = new ArrayList<>();
@@ -184,6 +187,7 @@ public class BenchMarkDemoDataLoader {
Dim dimension1 = new Dim("", "imp_date", DimensionType.time.name(), 0); Dim dimension1 = new Dim("", "imp_date", DimensionType.time.name(), 0);
dimension1.setTypeParams(new DimensionTimeTypeParams()); dimension1.setTypeParams(new DimensionTimeTypeParams());
dimensions.add(dimension1); dimensions.add(dimension1);
dimensions.add(new Dim("歌曲名称", "song_name", DimensionType.categorical.name(), 1));
dimensions.add(new Dim("国家", "country", DimensionType.categorical.name(), 1)); dimensions.add(new Dim("国家", "country", DimensionType.categorical.name(), 1));
dimensions.add(new Dim("语种", "languages", DimensionType.categorical.name(), 1)); dimensions.add(new Dim("语种", "languages", DimensionType.categorical.name(), 1));
dimensions.add(new Dim("发行时间", "releasedate", DimensionType.categorical.name(), 1)); dimensions.add(new Dim("发行时间", "releasedate", DimensionType.categorical.name(), 1));

View File

@@ -74,7 +74,7 @@ public class ChatDemoLoader implements CommandLineRunner {
addAgent1(); addAgent1();
addAgent2(); addAgent2();
addAgent3(); addAgent3();
addAgent4(); //addAgent4();
addSampleChats(); addSampleChats();
addSampleChats2(); addSampleChats2();
updateQueryScore(1); updateQueryScore(1);
@@ -248,6 +248,7 @@ public class ChatDemoLoader implements CommandLineRunner {
} }
agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
log.info("agent:{}", JsonUtil.toString(agent));
agentService.createAgent(agent, User.getFakeUser()); agentService.createAgent(agent, User.getFakeUser());
} }

View File

@@ -39,7 +39,7 @@ public class HeadlessDemoLoader implements CommandLineRunner {
} }
modelDataDemoLoader.doRun(); modelDataDemoLoader.doRun();
benchMarkDemoLoader.doRun(); benchMarkDemoLoader.doRun();
duSQLDemoDataLoader.doRun(); //duSQLDemoDataLoader.doRun();
isLoad = true; isLoad = true;
} }

View File

@@ -1,8 +0,0 @@
阿里云 _10_20 5
天猫 _10_20 5
腾讯游戏 _10_20 5
度小满 _10_20 5
京东金融 _10_20 5

View File

@@ -1,8 +0,0 @@
张勇 _10_22 5
马化腾 _10_22 5
朱光 _10_22 5
刘强东 _10_22 5

View File

@@ -1,4 +1,4 @@
美国 _5_8 1 美国 _3_8 1
加拿大 _5_8 1 加拿大 _3_8 1
锡尔赫特、吉大港、库斯蒂亚 _5_8 1 锡尔赫特、吉大港、库斯蒂亚 _3_8 1
孟加拉国 _5_8 3 孟加拉国 _3_8 3

View File

@@ -1,6 +1,6 @@
现代 _5_9 1 现代 _3_9 1
tagore _5_9 1 tagore _3_9 1
蓝调 _5_9 1 蓝调 _3_9 1
流行 _5_9 1 流行 _3_9 1
民间 _5_9 1 民间 _3_9 1
nazrul _5_9 1 nazrul _3_9 1

View File

@@ -1,4 +1,4 @@
美国 _6_10 1 美国 _3_11 1
印度 _6_10 2 印度 _3_11 2
英国 _6_10 1 英国 _3_11 1
孟加拉国 _6_10 2 孟加拉国 _3_11 2

View File

@@ -1,2 +1,2 @@
男性 _6_11 3 男性 _3_12 3
女性 _6_11 3 女性 _3_12 3

View File

@@ -1,2 +1,2 @@
mp4 _7_14 4 mp4 _3_14 4
mp3 _7_14 2 mp3 _3_14 2

View File

@@ -1,4 +1,4 @@
美国 _8_16 1 美国 _3_17 1
印度 _8_16 2 印度 _3_17 2
英国 _8_16 1 英国 _3_17 1
孟加拉国 _8_16 2 孟加拉国 _3_17 2

View File

@@ -1,2 +1,2 @@
英文 _8_17 2 英文 _3_18 2
孟加拉语 _8_17 4 孟加拉语 _3_18 4

View File

@@ -1,6 +1,6 @@
阿米·奥帕尔·霍伊 _8_19 1 阿米·奥帕尔·霍伊 _3_16 1
我的爱 _8_19 1 我的爱 _3_16 1
打败它 _8_19 1 打败它 _3_16 1
阿杰伊阿卡什 _8_19 1 阿杰伊阿卡什 _3_16 1
Tumi#长袍#尼罗布 _8_19 1 Tumi#长袍#尼罗布 _3_16 1
舒克诺#帕塔尔#努普尔#帕埃 _8_19 1 舒克诺#帕塔尔#努普尔#帕埃 _3_16 1

View File

@@ -1,5 +0,0 @@
百度集团 _9_15 5
阿里巴巴集团 _9_15 5
深圳市腾讯计算机系统有限公司 _9_15 5
北京京东世纪贸易有限公司 _9_15 5
网易公司 _9_15 5

View File

@@ -1,4 +0,0 @@
北京 _9_16 5
杭州 _9_16 5
深圳 _9_16 5

View File

@@ -1,7 +0,0 @@
李彦宏 _9_18 5
马云 _9_18 5
马化腾 _9_18 5
刘强东 _9_18 5
丁磊 _9_18 5

View File

@@ -1,7 +0,0 @@
李彦宏 _9_19 5
张勇 _9_19 5
刘炽平 _9_19 5
刘强东 _9_19 5
丁磊 _9_19 5

View File

@@ -1,2 +1,2 @@
root=. root=.
CustomDictionaryPath=data/dictionary/custom/DimValue_1_1.txt;data/dictionary/custom/DimValue_1_2.txt;data/dictionary/custom/DimValue_9_15.txt;data/dictionary/custom/DimValue_9_16.txt;data/dictionary/custom/DimValue_9_18.txt;data/dictionary/custom/DimValue_9_19.txt;data/dictionary/custom/DimValue_10_20.txt;data/dictionary/custom/DimValue_10_22.txt; CustomDictionaryPath=data/dictionary/custom/DimValue_1_1.txt;data/dictionary/custom/DimValue_1_2.txt;