[improvement][supersonic] add text-to-sql evaluation (#696)

* [improvement] llm supports all models

* [improvement] alias convert to SemanticParseInfo

* [improvement] support join

* [improvement] add evaluation.py

* [improvement] add text2sql_evalution.py

* [improvement] add text2sql_evalution.py

* [improvement] add evalution

* [improvement] add evalution

* [improvement] add evalution

---------

Co-authored-by: zuopengge <hwzuopengge@tencent.com>
This commit is contained in:
main
2024-01-30 10:46:45 +08:00
committed by GitHub
parent aae3d6b297
commit c398ac1a84
29 changed files with 3347 additions and 15 deletions

View File

@@ -25,4 +25,4 @@ LLM_PROVIDER_NAME = openai
 [LLMModel]
 MODEL_NAME = gpt-3.5-turbo-16k
 OPENAI_API_KEY = YOUR_API_KEY
-TEMPERATURE = 0.0
+TEMPERATURE = 0.0

View File

@@ -6,4 +6,4 @@ tiktoken==0.3.3
 uvicorn[standard]==0.21.1
 pandas==1.5.3
 loguru==0.7.2
-sqlglot==19.5.1
+sqlglot==19.5.1

View File

@@ -102,8 +102,18 @@ public class ParseInfoProcessor implements ParseResultProcessor {
     private Set<SchemaElement> getElements(Set<Long> modelIds, List<String> allFields, List<SchemaElement> elements) {
         return elements.stream()
-                .filter(schemaElement -> modelIds.contains(schemaElement.getModel()) && allFields.contains(
-                        schemaElement.getName())
+                .filter(schemaElement -> {
+                    if (CollectionUtils.isEmpty(schemaElement.getAlias())) {
+                        return modelIds.contains(schemaElement.getModel()) && allFields.contains(
+                                schemaElement.getName());
+                    }
+                    Set<String> allFieldsSet = new HashSet<>(allFields);
+                    Set<String> aliasSet = new HashSet<>(schemaElement.getAlias());
+                    List<String> intersection = allFieldsSet.stream()
+                            .filter(aliasSet::contains).collect(Collectors.toList());
+                    return modelIds.contains(schemaElement.getModel()) && (allFields.contains(
+                            schemaElement.getName()) || !CollectionUtils.isEmpty(intersection));
+                }
                 ).collect(Collectors.toSet());
     }
@@ -208,4 +218,4 @@ public class ParseInfoProcessor implements ParseResultProcessor {
                         (value1, value2) -> value2));
     }
-}
+}
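
The effect of the new filter: a schema element now matches when its model is selected and either its name or any of its aliases appears among the fields parsed from the SQL. A minimal Python sketch of the same predicate (hypothetical names, not part of this commit):

```python
def element_matches(model_id, name, aliases, model_ids, all_fields):
    # Mirrors the Java filter above: the element's model must be selected,
    # and either its name or one of its aliases must occur in all_fields.
    if model_id not in model_ids:
        return False
    if name in all_fields:
        return True
    return bool(set(aliases or []) & set(all_fields))

# e.g. an element named "annual_turnover" with alias "yearly_revenue"
# matches a query whose parsed fields contain "yearly_revenue".
assert element_matches(1, "annual_turnover", ["yearly_revenue"], {1}, ["yearly_revenue"])
```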

View File

@@ -23,6 +23,7 @@ import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
 import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
 import net.sf.jsqlparser.parser.CCJSqlParserUtil;
 import net.sf.jsqlparser.schema.Column;
+import net.sf.jsqlparser.schema.Table;
 import net.sf.jsqlparser.statement.select.GroupByElement;
 import net.sf.jsqlparser.statement.select.Join;
 import net.sf.jsqlparser.statement.select.OrderByElement;
@@ -231,6 +232,9 @@ public class SqlParserReplaceHelper {
             if (!CollectionUtils.isEmpty(joins)) {
                 for (Join join : joins) {
                     join.getOnExpression().accept(visitor);
+                    if (!(join.getRightItem() instanceof SubSelect)) {
+                        continue;
+                    }
                     SelectBody subSelectBody = ((SubSelect) join.getRightItem()).getSelectBody();
                     List<PlainSelect> plainSelectList = new ArrayList<>();
                     plainSelectList.add((PlainSelect) subSelectBody);
@@ -414,12 +418,17 @@ public class SqlParserReplaceHelper {
         List<Join> joins = painSelect.getJoins();
         if (!CollectionUtils.isEmpty(joins)) {
             for (Join join : joins) {
-                SelectBody subSelectBody = ((SubSelect) join.getRightItem()).getSelectBody();
-                List<PlainSelect> plainSelectList = new ArrayList<>();
-                plainSelectList.add((PlainSelect) subSelectBody);
-                List<PlainSelect> subPlainSelects = SqlParserSelectHelper.getPlainSelects(plainSelectList);
-                for (PlainSelect subPlainSelect : subPlainSelects) {
-                    subPlainSelect.getFromItem().accept(new TableNameReplaceVisitor(tableName));
+                if (join.getRightItem() instanceof SubSelect) {
+                    SelectBody subSelectBody = ((SubSelect) join.getRightItem()).getSelectBody();
+                    List<PlainSelect> plainSelectList = new ArrayList<>();
+                    plainSelectList.add((PlainSelect) subSelectBody);
+                    List<PlainSelect> subPlainSelects = SqlParserSelectHelper.getPlainSelects(plainSelectList);
+                    for (PlainSelect subPlainSelect : subPlainSelects) {
+                        subPlainSelect.getFromItem().accept(new TableNameReplaceVisitor(tableName));
+                    }
+                } else if (join.getRightItem() instanceof Table) {
+                    Table table = (Table) join.getRightItem();
+                    table.setName(tableName);
                 }
             }
         }
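
The change above makes table-name replacement handle both subselects and plain tables on the right side of a join. For comparison, the Python toolchain pins sqlglot in requirements.txt, where the same join-aware rewrite takes a few lines; a hedged sketch, not code from this commit:

```python
import sqlglot
from sqlglot import exp

sql = "SELECT a.x FROM t1 AS a JOIN (SELECT id FROM t2) AS b ON a.id = b.id"
tree = sqlglot.parse_one(sql)
# find_all(exp.Table) walks into subselects too, so both the plain table t1
# and the table t2 inside the joined subquery get renamed.
for table in tree.find_all(exp.Table):
    table.set("this", exp.to_identifier("internet"))
print(tree.sql())
```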

5
evaluation/README_CN.md Normal file
View File

@@ -0,0 +1,5 @@
# Evaluation Process
1. Start the project normally (the LLM service must be running).
2. Put the questions to be evaluated in the evaluation/data directory as internet.txt, and put the gold SQL for those questions in the same directory as gold_example_dusql.txt.
3. Run the evaluation.sh script, which builds the table data, fetches the model's predictions, and runs the comparison logic; a driver sketch follows below. The accuracy is printed to the console, and failing cases are written to eval.json in the same directory.
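
The steps map onto the Python entry points added in this commit; a minimal driver sketch, assuming the function names defined in build_tables.py and build_pred_result.py:

```python
# Hypothetical Python equivalent of the evaluation.sh steps.
from build_tables import build_table            # builds the SQLite tables
from build_pred_result import get_pred_result   # fetches model predictions

build_table()      # writes evaluation/data/internet.db from config/config.yaml
get_pred_result()  # writes evaluation/data/pred_example_dusql.txt
# evaluation.py then compares pred_example_dusql.txt against
# gold_example_dusql.txt, prints accuracy, and dumps failing cases to eval.json.
```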

74
evaluation/build_pred_result.py Normal file
View File

@@ -0,0 +1,74 @@
import requests
import logging
import json
import jwt
import time
import os
import yaml

def read_query(input_path):
    # Read one evaluation question per line.
    result = []
    with open(input_path, "r") as f:
        for line in f.readlines():
            result.append(line.strip('\n'))
    return result


def write_sql(output_path, result):
    # Append the predicted SQL statements to the output file.
    with open(output_path, mode='a') as file:
        file.writelines(result)


def get_pred_sql(query, url, agentId, chatId, authorization, default_sql):
    # Ask the running chat service to parse the question into SQL;
    # fall back to default_sql when the request or the parse fails.
    url = url + "/api/chat/query/parse"
    data = {"agentId": agentId, "chatId": chatId, "queryText": query}
    header = {"Authorization": authorization}
    try:
        result = requests.post(url=url, headers=header, json=data)
        if result.status_code == 200:
            data = result.json()["data"]
            selectedParses = data["selectedParses"]
            if selectedParses is not None and len(selectedParses) > 0:
                querySQL = selectedParses[0]["sqlInfo"]["querySQL"]
                querySQL = querySQL.replace("`dusql`.", "").replace("dusql", "").replace("\n", "")
                return querySQL + '\n'
        return default_sql + '\n'
    except Exception as e:
        # Do not touch `result` here; it is unbound if requests.post() raised.
        print(url)
        print(e)
        logging.info(e)
        return default_sql + '\n'


def get_authorization():
    # Build a short-lived admin JWT accepted by the service.
    exp = time.time() + 1000
    token = jwt.encode({"token_user_name": "admin", "exp": exp}, "secret", algorithm="HS512")
    return "Bearer " + token


def get_pred_result():
    current_directory = os.path.dirname(os.path.abspath(__file__))
    config_file = current_directory + "/config/config.yaml"
    with open(config_file, 'r') as file:
        config = yaml.safe_load(file)
    input_path = current_directory + "/data/" + config["domain"] + ".txt"
    pred_sql_path = current_directory + "/data/" + "pred_example_dusql.txt"
    if os.path.exists(pred_sql_path):
        os.remove(pred_sql_path)
        print("pred_sql_path removed!")
    agent_id = config["agent_id"]
    chat_id = config["chat_id"]
    url = config["url"]
    authorization = get_authorization()
    print(input_path)
    print(pred_sql_path)
    questions = read_query(input_path)
    pred_sql_list = []
    default_sql = "select * from tablea "
    for question in questions:
        pred_sql_list.append(get_pred_sql(question, url, agent_id, chat_id, authorization, default_sql))
    write_sql(pred_sql_path, pred_sql_list)


if __name__ == "__main__":
    get_pred_result()

340
evaluation/build_tables.py Normal file
View File

@@ -0,0 +1,340 @@
import sqlite3
import os
import datetime
import yaml
def build_internet(path,day):
imp_date=(datetime.datetime.now()+datetime.timedelta(days=day)).strftime("%Y-%m-%d")
print(imp_date)
conn = sqlite3.connect(path+'/internet.db')
cursor = conn.cursor()
create_table_query = '''
CREATE TABLE IF NOT EXISTS company (
`imp_date` varchar(50) ,
`company_id` varchar(50) NOT NULL ,
`company_name` varchar(50) NOT NULL ,
`headquarter_address` varchar(50) NOT NULL ,
`company_established_time` varchar(20) NOT NULL ,
`founder` varchar(20) NOT NULL ,
`ceo` varchar(20) NOT NULL ,
`annual_turnover` bigint(15) ,
`employee_count` int(7) ,
PRIMARY KEY (`company_id`)
)
'''
cursor.execute(create_table_query)
insert_data_query = '''
INSERT INTO company (imp_date,company_id,company_name,headquarter_address,company_established_time,founder,ceo,annual_turnover,employee_count)
VALUES (?, ?, ?,?, ?, ?,?, ?, ?)
'''
data = [
(imp_date,"item_enterprise_13_131","百度集团","北京","2000","李彦宏","李彦宏",102300000000,40000),
(imp_date,"item_enterprise_13_132","阿里巴巴集团","杭州","1999年","马云","张勇",376800000000,103699),
(imp_date,"item_enterprise_13_133","深圳市腾讯计算机系统有限公司","深圳","1998","马化腾","刘炽平",321600000000,56310),
(imp_date,"item_enterprise_13_134","北京京东世纪贸易有限公司","北京","1998","刘强东","刘强东",28800000000,179000),
(imp_date,"item_enterprise_13_135","网易公司","杭州","1997","丁磊","丁磊",67500000000,20000)
]
cursor.executemany(insert_data_query, data)
conn.commit()
create_table_query = '''
CREATE TABLE IF NOT EXISTS brand (
`imp_date` varchar(50) ,
`brand_id` varchar(50) NOT NULL ,
`brand_name` varchar(50) NOT NULL ,
`brand_established_time` varchar(20) NOT NULL ,
`company_id` varchar(50) NOT NULL ,
`legal_representative` varchar(20) NOT NULL ,
`registered_capital` bigint(15) ,
PRIMARY KEY (`brand_id`)
)
'''
cursor.execute(create_table_query)
insert_data_query = '''
INSERT INTO brand (imp_date,brand_id,brand_name,brand_established_time,company_id,legal_representative,registered_capital)
VALUES (?, ?, ?,?, ?, ?,?)
'''
data = [
(imp_date,"item_enterprise_13_136","阿里云","2009年9月10日","item_enterprise_13_134","张勇",50000000),
(imp_date,"item_enterprise_13_137","天猫","2012年1月11日","item_enterprise_13_134","张勇",100000000),
(imp_date,"item_enterprise_13_138","腾讯游戏","2003","item_enterprise_13_131","马化腾",50000000),
(imp_date,"item_enterprise_13_139","度小满","2018","item_enterprise_13_132","朱光",100000000),
(imp_date,"item_enterprise_13_140","京东金融","2017","item_enterprise_13_134","刘强东",100000000)
]
cursor.executemany(insert_data_query, data)
conn.commit()
create_table_query = '''
CREATE TABLE IF NOT EXISTS company_revenue (
`imp_date` varchar(50) ,
`company_id` varchar(50) NOT NULL ,
`brand_id` varchar(50) NOT NULL ,
`revenue_proportion` double NOT NULL,
`profit_proportion` double NOT NULL ,
`expenditure_proportion` double NOT NULL
)
'''
cursor.execute(create_table_query)
insert_data_query = '''
INSERT INTO company_revenue (imp_date,company_id,brand_id,revenue_proportion,profit_proportion,expenditure_proportion)
VALUES (?, ?, ?,?, ?, ?)
'''
data = [
(imp_date,"item_enterprise_13_131","item_enterprise_13_139",0.1,0.1,0.3),
(imp_date,"item_enterprise_13_134","item_enterprise_13_138",0.8,0.8,0.6),
(imp_date,"item_enterprise_13_135","item_enterprise_13_139",0.8,0.8,0.6),
(imp_date,"item_enterprise_13_131","item_enterprise_13_137",0.8,0.8,0.6),
(imp_date,"item_enterprise_13_135","item_enterprise_13_137",0.1,0.1,0.3)
]
cursor.executemany(insert_data_query, data)
conn.commit()
create_table_query = '''
CREATE TABLE IF NOT EXISTS company_brand_revenue (
`imp_date` varchar(50) ,
`year_time` varchar(10) NOT NULL ,
`brand_id` varchar(50) NOT NULL ,
`revenue` bigint(15) NOT NULL,
`profit` bigint(15) NOT NULL ,
`revenue_growth_year_on_year` double NOT NULL ,
`profit_growth_year_on_year` double NOT NULL
)
'''
cursor.execute(create_table_query)
insert_data_query = '''
INSERT INTO company_brand_revenue (imp_date,year_time,brand_id,revenue,profit,revenue_growth_year_on_year,profit_growth_year_on_year)
VALUES (?, ?, ?,?, ?, ?,?)
'''
data = [
(imp_date, "2018", "item_enterprise_13_138", 500000000, -300000000, 0.1, -0.1),
(imp_date, "2019", "item_enterprise_13_136", 100000000000, 50000000000, 1, 0.5),
(imp_date, "2018", "item_enterprise_13_137", 100000000000, 50000000000, 1, -0.1),
(imp_date, "2018", "item_enterprise_13_139", 500000000, 50000000000, 0.1, 0.5),
(imp_date, "2018", "item_enterprise_13_138", 100000000000, -300000000, 0.1, 0.5)
]
cursor.executemany(insert_data_query, data)
conn.commit()
conn.close()
def build_china_travel_agency(path,day):
imp_date=(datetime.datetime.now()+datetime.timedelta(days=day)).strftime("%Y-%m-%d")
print(imp_date)
conn = sqlite3.connect(path+'/china_travel_agency.db')
cursor = conn.cursor()
create_table_query = '''
CREATE TABLE IF NOT EXISTS `travel_agency` (
`imp_date` varchar(50) ,
`travel_agency_id` varchar(50) NOT NULL,
`travel_agency_name` varchar(50) NOT NULL,
`travel_agency_level` varchar(50) NOT NULL,
`number_countrie_outbound_travel` int(7) ,
`number_domestic_tourist_cities` int(7) ,
`number_outbound_travel_routes` int(7) ,
`number_domestic_travel_routes` int(7) ,
`asia_ranking` int(7) ,
`number_overseas_tourists_received` int(7) ,
`number_overseas_companies` int(7) ,
`number_holding_subsidiaries` int(7) ,
`number_traveling_salesmen_business_relationships` int(7) ,
`number_duty_free_shops` int(7) ,
PRIMARY KEY (`travel_agency_id`)
)
'''
cursor.execute(create_table_query)
insert_data_query = '''
INSERT INTO travel_agency (imp_date, travel_agency_id,travel_agency_name,travel_agency_level,number_countrie_outbound_travel,
number_domestic_tourist_cities,number_outbound_travel_routes,number_domestic_travel_routes,
asia_ranking,number_overseas_tourists_received,number_overseas_companies,number_holding_subsidiaries,
number_traveling_salesmen_business_relationships,number_duty_free_shops)
VALUES (?, ?, ?,?, ?, ?,?, ?, ?,?, ?, ?, ?, ?)
'''
data = [
(imp_date,"item_enterprise_7_56","中国国旅","3A",50,30,500,1000,1,10000000,5,70,5000,100),
(imp_date,"item_enterprise_7_57","中旅总社","4A",90,100,1000,2000,100,20000000,20,120,8000,180),
(imp_date,"item_enterprise_7_58","中青旅控股股份有限公司","5A",50,30,500,2000,1,10000000,5,70,5000,100),
(imp_date,"item_enterprise_7_59","中国康辉旅游集团有限公司","5A",50,30,1000,1000,100,10000000,5,70,8000,100),
(imp_date,"item_enterprise_7_60","众信旅游集团股份有限公司","5A",90,30,1000,2000,1,20000000,20,70,8000,180)
]
cursor.executemany(insert_data_query, data)
conn.commit()
create_table_query = '''
CREATE TABLE IF NOT EXISTS `outbound_travel_routes` (
`imp_date` varchar(50) ,
`outbound_route_id` varchar(50) NOT NULL ,
`outbound_route_name` varchar(50) NOT NULL ,
`travel_agency_id` varchar(50) NOT NULL ,
`outbound_departure_city` varchar(50) NOT NULL ,
`outbound_days` bigint(5) ,
`adult_price` bigint(5) ,
`child_price` bigint(5) ,
`countries` bigint(5) ,
`attractions` bigint(5) ,
`total_ticket_price` bigint(5) ,
PRIMARY KEY (`outbound_route_id`)
)
'''
cursor.execute(create_table_query)
insert_data_query = '''
INSERT INTO outbound_travel_routes (imp_date, outbound_route_id,outbound_route_name,travel_agency_id,
outbound_departure_city,outbound_days,adult_price,child_price,countries,attractions,total_ticket_price)
VALUES (?, ?, ?,?, ?, ?,?,?, ?, ?,?)
'''
data = [
(imp_date,"item_enterprise_7_61","德法瑞意深度纵览一价无忧","item_enterprise_7_59","北京",12,10900,8900,5,15,750),
(imp_date,"item_enterprise_7_62","意大利全景+西西里精华深度纵览","item_enterprise_7_59","天津",20,18900,15900,10,25,2500),
(imp_date,"item_enterprise_7_63","悦选意大利经典大城小镇书香之旅","item_enterprise_7_57","上海",20,10900,8900,5,15,2500),
(imp_date,"item_enterprise_7_64","新西兰南岛双冰川双峡湾深度纯净之旅","item_enterprise_7_59","哈尔滨",12,18900,8900,5,15,2500),
(imp_date,"item_enterprise_7_65","英国+爱尔兰+威尔士精选之旅","item_enterprise_7_57","深圳",12,18900,15900,10,15,750)
]
cursor.executemany(insert_data_query, data)
conn.commit()
create_table_query = '''
CREATE TABLE IF NOT EXISTS `country_outbound_travel` (
`imp_date` varchar(50) ,
`outbound_travel_route_id` varchar(50) NOT NULL ,
`nation` varchar(20) NOT NULL ,
`travel_days` int(6) NOT NULL ,
`outbound_number_attractions` int(6) NOT NULL
)
'''
cursor.execute(create_table_query)
insert_data_query = '''
INSERT INTO country_outbound_travel (imp_date, outbound_travel_route_id,nation,travel_days,outbound_number_attractions)
VALUES (?, ?, ?,?, ?)
'''
data = [
(imp_date,"item_enterprise_7_64","意大利",2,3),
(imp_date,"item_enterprise_7_63","法国",4,5),
(imp_date,"item_enterprise_7_62","德国",4,5),
(imp_date,"item_enterprise_7_65","瑞士",4,3),
(imp_date,"item_enterprise_7_61","英格兰",4,3)
]
cursor.executemany(insert_data_query, data)
conn.commit()
create_table_query = '''
CREATE TABLE IF NOT EXISTS `domestic_travel_routes` (
`imp_date` varchar(50) ,
`domestic_travel_route_id` varchar(50) NOT NULL ,
`domestic_travel_route_name` varchar(50) NOT NULL ,
`travel_agency_id` varchar(50) NOT NULL ,
`domestic_departure_city` varchar(50) NOT NULL ,
`domestic_days` int(5) ,
`presale_price` int(8) NOT NULL ,
`tour_price` int(8) NOT NULL ,
`number_people_group` int(8) NOT NULL ,
`personal_price` int(8) NOT NULL ,
`domestic_number_attractions` int(6) NOT NULL ,
PRIMARY KEY (`domestic_travel_route_id`)
)
'''
cursor.execute(create_table_query)
insert_data_query = '''
INSERT INTO domestic_travel_routes (imp_date, domestic_travel_route_id,domestic_travel_route_name,travel_agency_id,domestic_departure_city,domestic_days,presale_price,tour_price,number_people_group,personal_price,domestic_number_attractions)
VALUES (?, ?, ?,?, ?, ?,?,?, ?, ?,?)
'''
data = [
(imp_date,"item_enterprise_7_66","桂林深度跟团游","item_enterprise_7_60",'北京',4,2500,2000,2,3000,10),
(imp_date,"item_enterprise_7_67","厦门休闲游","item_enterprise_7_56",'天津',8,6500,5000,5,7000,20),
(imp_date,"item_enterprise_7_68","重庆红色之旅","item_enterprise_7_60",'上海',4,6500,2000,5,3000,20),
(imp_date,"item_enterprise_7_69","云南古城游","item_enterprise_7_59",'哈尔滨',4,6500,2000,5,7000,20),
(imp_date,"item_enterprise_7_70","上海时尚游","item_enterprise_7_59",'深圳',4,6500,2000,5,7000,10)
]
cursor.executemany(insert_data_query, data)
conn.commit()
create_table_query = '''
CREATE TABLE IF NOT EXISTS `cruise_route` (
`imp_date` varchar(50) ,
`cruise_route_id` varchar(50) NOT NULL ,
`cruise_route_name` varchar(50) NOT NULL ,
`travel_agency_id` varchar(50) NOT NULL ,
`cruise_departure_city` varchar(50) NOT NULL ,
`cruise_days` int(5) ,
`interior_cabin_price` int(8) NOT NULL ,
`sea_view_room_price` int(8) NOT NULL ,
`balcony_room_price` int(8) NOT NULL ,
`sailing_area` varchar(50) NOT NULL ,
`cruise_line` varchar(50) NOT NULL ,
PRIMARY KEY (`cruise_route_id`)
)
'''
cursor.execute(create_table_query)
insert_data_query = '''
INSERT INTO cruise_route (imp_date, cruise_route_id,cruise_route_name,travel_agency_id,cruise_departure_city,cruise_days,interior_cabin_price,sea_view_room_price,balcony_room_price,sailing_area,cruise_line)
VALUES (?, ?, ?,?, ?, ?,?,?, ?, ?,?)
'''
data = [
(imp_date,"item_enterprise_7_71","南极摄影旅游","item_enterprise_7_57",'大连',6,4399,4799,5299,"日本航线","皇家加勒比国际游轮"),
(imp_date,"item_enterprise_7_72","地中海巡游","item_enterprise_7_58",'天津',10,6399,6799,7399,"韩国航线","海洋亚特兰蒂游轮"),
(imp_date,"item_enterprise_7_73","超凡体验来自未来的游轮","item_enterprise_7_60",'上海',10,4399,4799,5299,"南极航线","庞洛游轮"),
(imp_date,"item_enterprise_7_74","超凡体验来自未来的游轮","item_enterprise_7_60",'深圳',10,6399,6799,5299,"南极航线","庞洛游轮"),
(imp_date,"item_enterprise_7_75","超凡体验来自未来的游轮","item_enterprise_7_60",'天津',10,6399,4799,5299,"韩国航线","海洋亚特兰蒂游轮")
]
cursor.executemany(insert_data_query, data)
conn.commit()
conn.close()
def build_table():
current_directory = os.path.dirname(os.path.abspath(__file__))
config_file=current_directory+"/config/config.yaml"
with open(config_file, 'r') as file:
config = yaml.safe_load(file)
db_path=current_directory+"/data/"
db_file=db_path+config["domain"]+".db"
db_exist=os.path.exists(db_file)
if db_exist:
os.remove(db_file)
print("db_file removed!")
print(db_file)
build_internet(db_path,-1)
if __name__ == '__main__':
    build_table()
#build_china_travel_agency(path)
# sql="SELECT T3.company_name, T3.annual_turnover, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id"
# query_sql(sql)
# sql =" select * from ( SELECT brand_name, company_name, sum(revenue_proportion), sum(annual_turnover) FROM (select `sys_imp_date`, `company_name`, `brand_name`, `annual_turnover`, `revenue_proportion`from(select `sys_imp_date`, `company_name`, `brand_name`, `annual_turnover` as `annual_turnover`, `revenue_proportion` as `revenue_proportion`from(select `annual_turnover` as `annual_turnover`, `revenue_proportion` as `revenue_proportion`, `sys_imp_date`, `company_name`, `brand_name`from(select `src1_brand`.`sys_imp_date` as `sys_imp_date`, `src1_company`.`annual_turnover` as `annual_turnover`, `src1_company_revenue`.`revenue_proportion` as `revenue_proportion`, `src1_brand`.`company_id` as `company_id`, `src1_company`.`company_name` as `company_name`, `src1_brand`.`brand_name` as `brand_name`, `src1_brand`.`brand_id` as `brand_id`from(select `annual_turnover` as `annual_turnover`, `imp_date` as `sys_imp_date`, `company_id`, `company_name` as `company_name`, `imp_date` as `imp_date`from(select `imp_date`, `company_id`, `company_name`, `headquarter_address`, `company_established_time`, `founder`, `ceo`, `annual_turnover`, `employee_count`from`company`) as `company`) as `src1_company`inner join (select `revenue_proportion` as `revenue_proportion`, `imp_date` as `sys_imp_date`, `company_id`, `brand_id`, `imp_date` as `imp_date`from(select `imp_date`, `company_id`, `brand_id`, `revenue_proportion`, `profit_proportion`, `expenditure_proportion`from`company_revenue`) as `company_revenue`) as `src1_company_revenue` on `src1_company`.`company_id` = `src1_company_revenue`.`company_id`inner join (select `imp_date` as `sys_imp_date`, `company_id`, `brand_name` as `brand_name`, `brand_id`, `imp_date` as `imp_date`from(select `imp_date`, `brand_id`, `brand_name`, `brand_established_time`, `company_id`, `legal_representative`, `registered_capital`from`brand`) as `brand`) as `src1_brand` on `src1_company`.`company_id` = `src1_brand`.`company_id`) as `src11_`) as `company_company_revenue_brand_0`) as `company_company_revenue_brand_1`) t_103 WHERE sys_imp_date = '2024-01-11' GROUP BY brand_name, company_name, sys_imp_date ORDER BY sum(revenue_proportion) DESC ) a limit 1000 "
# a=query_sql(sql)
# print(a[0][0])
# import pymysql
#
# db = pymysql.connect(host="11.154.212.211", user="semantic", password="semantic2023", database="internet", port=3306)
#
# # Use cursor() to obtain a cursor
# cursor = db.cursor()
#
# # Use execute() to run the SQL statement
# cursor.execute("select * from company")
#
# # Use fetchone() to fetch a single row
# data = cursor.fetchone()
# print(data)

View File

4
evaluation/config/config.yaml Normal file
View File

@@ -0,0 +1,4 @@
chat_id: 3
agent_id: 4
domain: internet
url: http://localhost:9080
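
The scripts in this directory read these four keys with yaml.safe_load; for instance (paths as used in build_tables.py and build_pred_result.py):

```python
import os
import yaml

here = os.path.dirname(os.path.abspath(__file__))
with open(here + "/config/config.yaml", "r") as f:
    config = yaml.safe_load(f)
# domain selects both the question file (internet.txt) and the SQLite
# database (internet.db); url points at the running chat service.
print(config["agent_id"], config["chat_id"], config["domain"], config["url"])
```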

100
evaluation/data/gold_example_dusql.txt Normal file
View File

@@ -0,0 +1,100 @@
SELECT T3.company_name, T3.annual_turnover, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id internet
SELECT T3.company_name, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id internet
SELECT T3.company_name, T2.brand_name, T2.legal_representative, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id internet
SELECT T3.company_name, T3.headquarter_address, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id internet
SELECT T3.company_name, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id WHERE T1.profit_proportion <= 0.1 internet
SELECT T3.company_name, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id WHERE T1.profit_proportion < 0.1 internet
SELECT T3.company_name, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id WHERE T1.profit_proportion > 0.1 internet
SELECT T2.brand_name, T2.legal_representative FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital >= 100000000 GROUP BY T1.brand_id ORDER BY avg(T1.revenue_proportion) DESC LIMIT 1 internet
SELECT T2.brand_name, T2.legal_representative FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital > 100000000 GROUP BY T1.brand_id ORDER BY count(*) ASC LIMIT 5 internet
SELECT T2.brand_name, avg(T1.revenue_proportion), T2.legal_representative FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital < 100000000 GROUP BY T1.brand_id internet
SELECT T2.brand_name, avg(T1.revenue_proportion), T2.legal_representative FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital >= 100000000 GROUP BY T1.brand_id internet
SELECT T2.brand_name, sum(T1.revenue_proportion), T2.legal_representative FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital <= 100000000 GROUP BY T1.brand_id internet
SELECT T2.brand_name, max(T1.revenue_proportion), T2.legal_representative FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital < 100000000 GROUP BY T1.brand_id internet
SELECT T2.brand_name, max(T1.revenue_proportion), T2.legal_representative FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital <= 100000000 GROUP BY T1.brand_id internet
SELECT T2.brand_name, T2.legal_representative FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital > 100000000 GROUP BY T1.brand_id HAVING avg(T1.revenue_proportion) = 0.5 internet
SELECT T2.brand_name, T2.legal_representative FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital <= 100000000 GROUP BY T1.brand_id HAVING count(*) = 5 internet
SELECT T2.brand_name, avg(T1.revenue_proportion) FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital <= 100000000 GROUP BY T1.brand_id HAVING avg(T1.expenditure_proportion) <= 0.45 internet
SELECT T2.brand_name, min(T1.revenue_proportion) FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital < 100000000 GROUP BY T1.brand_id HAVING count(*) <= 5 internet
SELECT T2.legal_representative, T2.brand_name, avg(T1.revenue_proportion) FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id internet
SELECT T2.legal_representative, T2.brand_name, min(T1.revenue_proportion) FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id internet
SELECT T2.legal_representative, T2.brand_name, sum(T1.revenue_proportion) FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id internet
SELECT T2.legal_representative, T2.brand_name, max(T1.revenue_proportion) FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id internet
SELECT T2.legal_representative, T2.brand_name FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id HAVING count(*) <= 5 internet
SELECT T2.legal_representative, T2.brand_name FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id HAVING avg(T1.revenue_proportion) > 0.5 internet
SELECT T2.brand_name, avg(T1.revenue_proportion) FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id HAVING avg(T1.profit_proportion) <= 0.6 internet
SELECT T2.brand_name, min(T1.revenue_proportion) FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id HAVING count(*) = 5 internet
SELECT T2.brand_name, T2.legal_representative, avg(T1.revenue_proportion) FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id ORDER BY avg(T1.profit_proportion) DESC LIMIT 1 internet
SELECT T2.brand_name, T2.legal_representative, sum(T1.revenue_proportion) FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id ORDER BY count(*) DESC LIMIT 3 internet
SELECT T3.company_name, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id ORDER BY T1.profit_proportion ASC internet
SELECT T2.brand_name, T3.company_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id ORDER BY T1.expenditure_proportion ASC LIMIT 3 internet
SELECT T2.brand_name, T3.company_name, T1.revenue_proportion, T1.profit_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id ORDER BY T1.expenditure_proportion DESC LIMIT 3 internet
SELECT T2.brand_name, T3.company_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id ORDER BY T1.profit_proportion DESC LIMIT 3 internet
SELECT brand_name FROM brand WHERE legal_representative NOT IN (SELECT legal_representative FROM brand GROUP BY legal_representative HAVING avg(registered_capital) < 1000000) internet
SELECT T1.brand_name, T2.company_name, T1.legal_representative, T2.annual_turnover FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id internet
SELECT T1.brand_name, T2.company_name, T1.registered_capital FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id internet
SELECT T1.brand_name, T2.company_name, T1.registered_capital, T2.annual_turnover FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id internet
SELECT T1.brand_name, T2.company_name, T1.registered_capital, T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id internet
SELECT T1.brand_name, T2.company_name, T1.legal_representative, T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id internet
SELECT T2.company_name, T2.headquarter_address FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id internet
SELECT T2.company_name, T2.annual_turnover FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id internet
SELECT T3.company_name, T3.headquarter_address, T2.brand_name, T1.revenue FROM company_brand_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T2.company_id = T3.company_id internet
SELECT T3.company_name, T3.annual_turnover, T2.brand_name, T1.revenue FROM company_brand_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T2.company_id = T3.company_id internet
SELECT T3.company_name, T2.brand_name, T2.legal_representative, T1.revenue FROM company_brand_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T2.company_id = T3.company_id internet
SELECT T3.company_name, T2.brand_name, T1.revenue FROM company_brand_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T2.company_id = T3.company_id internet
SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital >= 100000000 internet
SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital < 100000000 internet
SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital > 100000000 internet
SELECT T2.company_name, T2.annual_turnover, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital <= 100000000 internet
SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital > 100000000 AND T2.annual_turnover <= 28800000000 internet
SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital > 100000000 AND T2.annual_turnover < 28800000000 internet
SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital < 100000000 AND T2.annual_turnover < 28800000000 internet
SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital <= 100000000 AND T2.annual_turnover >= 28800000000 internet
SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital >= 100000000 AND T2.annual_turnover > 28800000000 internet
SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital <= 100000000 AND T2.annual_turnover > 28800000000 internet
SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital > 100000000 AND T2.annual_turnover > 28800000000 internet
SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital <= 100000000 AND T2.annual_turnover < 28800000000 internet
SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital < 100000000 AND T2.annual_turnover <= 28800000000 internet
SELECT T3.company_name, T2.brand_name, T1.revenue FROM company_brand_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T2.company_id = T3.company_id WHERE T1.profit <= 50000000000 internet
SELECT T3.company_name, T2.brand_name, T1.revenue FROM company_brand_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T2.company_id = T3.company_id WHERE T1.revenue_growth_year_on_year >= 1 internet
SELECT T2.company_name, T2.headquarter_address FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover >= 28800000000 GROUP BY T1.company_id ORDER BY avg(T1.revenue_proportion) ASC LIMIT 5 internet
SELECT T2.company_name, T2.headquarter_address FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover >= 28800000000 GROUP BY T1.company_id ORDER BY avg(T1.revenue_proportion) DESC LIMIT 1 internet
SELECT T2.brand_name, T2.legal_representative FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital > 100000000 GROUP BY T1.brand_id ORDER BY avg(T1.revenue) DESC LIMIT 1 internet
SELECT T2.brand_name, T2.legal_representative FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital <= 100000000 GROUP BY T1.brand_id ORDER BY avg(T1.revenue) DESC LIMIT 1 internet
SELECT T2.company_name, T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover >= 28800000000 GROUP BY T1.company_id ORDER BY avg(T1.registered_capital) DESC LIMIT 5 internet
SELECT T2.company_name, T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover >= 28800000000 GROUP BY T1.company_id ORDER BY sum(T1.registered_capital) ASC LIMIT 5 internet
SELECT T2.company_name, max(T1.revenue_proportion), T2.headquarter_address FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover >= 28800000000 GROUP BY T1.company_id internet
SELECT T2.company_name, sum(T1.revenue_proportion), T2.headquarter_address FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover < 28800000000 GROUP BY T1.company_id internet
SELECT T2.company_name, max(T1.revenue_proportion), T2.headquarter_address FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover > 28800000000 GROUP BY T1.company_id internet
SELECT T2.company_name, sum(T1.revenue_proportion), T2.headquarter_address FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover >= 28800000000 GROUP BY T1.company_id internet
SELECT T2.company_name, avg(T1.revenue_proportion), T2.headquarter_address FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover <= 28800000000 GROUP BY T1.company_id internet
SELECT T2.brand_name, avg(T1.revenue), T2.legal_representative FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital <= 100000000 GROUP BY T1.brand_id internet
SELECT T2.brand_name, avg(T1.revenue), T2.legal_representative FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital >= 100000000 GROUP BY T1.brand_id internet
SELECT T2.brand_name, min(T1.revenue), T2.legal_representative FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital < 100000000 GROUP BY T1.brand_id internet
SELECT T2.brand_name, max(T1.revenue), T2.legal_representative FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital <= 100000000 GROUP BY T1.brand_id internet
SELECT T2.brand_name, avg(T1.revenue), T2.legal_representative FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital < 100000000 GROUP BY T1.brand_id internet
SELECT T2.company_name, min(T1.registered_capital), T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover > 28800000000 GROUP BY T1.company_id internet
SELECT T2.company_name, min(T1.registered_capital), T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover <= 28800000000 GROUP BY T1.company_id internet
SELECT T2.company_name, avg(T1.registered_capital), T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover <= 28800000000 GROUP BY T1.company_id internet
SELECT T2.company_name, max(T1.registered_capital), T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover <= 28800000000 GROUP BY T1.company_id internet
SELECT T2.company_name, sum(T1.registered_capital), T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover > 28800000000 GROUP BY T1.company_id internet
SELECT T2.company_name, T2.headquarter_address FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover <= 28800000000 GROUP BY T1.company_id HAVING sum(T1.revenue_proportion) <= 0.5 internet
SELECT T2.company_name, T2.headquarter_address FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover < 28800000000 GROUP BY T1.company_id HAVING sum(T1.revenue_proportion) <= 0.5 internet
SELECT T2.company_name, T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover >= 28800000000 GROUP BY T1.company_id HAVING count(*) <= 5 internet
SELECT T2.company_name, T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover <= 28800000000 GROUP BY T1.company_id HAVING count(*) < 5 internet
SELECT T2.company_name, min(T1.revenue_proportion) FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover > 28800000000 GROUP BY T1.company_id HAVING count(*) > 5 internet
SELECT T2.company_name, max(T1.revenue_proportion) FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover < 28800000000 GROUP BY T1.company_id HAVING avg(T1.profit_proportion) > 0.75 internet
SELECT T2.brand_name, max(T1.revenue) FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital <= 100000000 GROUP BY T1.brand_id HAVING sum(T1.profit_growth_year_on_year) <= 1000000 internet
SELECT T2.brand_name, min(T1.revenue) FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital < 100000000 GROUP BY T1.brand_id HAVING count(*) < 5 internet
SELECT T2.company_name, sum(T1.registered_capital) FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover > 28800000000 GROUP BY T1.company_id HAVING count(*) < 5 internet
SELECT T2.company_name, sum(T1.registered_capital) FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover < 28800000000 GROUP BY T1.company_id HAVING count(*) >= 5 internet
SELECT T2.headquarter_address, T2.company_name, max(T1.revenue_proportion) FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id GROUP BY T1.company_id internet
SELECT T2.headquarter_address, T2.company_name, min(T1.revenue_proportion) FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id GROUP BY T1.company_id internet
SELECT T2.headquarter_address, T2.company_name, sum(T1.revenue_proportion) FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id GROUP BY T1.company_id internet
SELECT T2.headquarter_address, T2.company_name, avg(T1.revenue_proportion) FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id GROUP BY T1.company_id internet
SELECT T2.legal_representative, T2.brand_name, avg(T1.revenue) FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id internet
SELECT T2.legal_representative, T2.brand_name, min(T1.revenue) FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id internet
SELECT T2.legal_representative, T2.brand_name, sum(T1.revenue) FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id internet
SELECT T2.legal_representative, T2.brand_name, max(T1.revenue) FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id internet
SELECT T2.headquarter_address, T2.company_name, sum(T1.registered_capital) FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id GROUP BY T1.company_id internet
SELECT T2.headquarter_address, T2.company_name, max(T1.registered_capital) FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id GROUP BY T1.company_id internet

100
evaluation/data/internet.txt Normal file
View File

@@ -0,0 +1,100 @@
在各公司所有品牌收入排名中,给出每一个品牌,其所在公司以及收入占该公司的总收入比例,同时给出该公司的年营业额
在各公司所有品牌收入排名中,给出每一个品牌,其所在公司以及收入占该公司的总收入比例
在各公司所有品牌收入排名中,给出每一个品牌和其法人,其所在公司以及收入占该公司的总收入比例
在各公司所有品牌收入排名中,给出每一个品牌,其所在公司以及收入占该公司的总收入比例,同时给出该公司总部所在地
在公司各品牌收入排名的利润占比最多10%时,给出公司的名称品牌的名称并给出公司各品牌收入排名的营收占比
在公司各品牌收入排名的利润占比小于10%时,给出公司的名称品牌的名称并给出公司各品牌收入排名的营收占比
在公司各品牌收入排名的利润占比不止10%时,给出公司的名称品牌的名称并给出公司各品牌收入排名的营收占比
注册资本不小于1亿的品牌中哪个品牌的平均营收占比最大并给出它的法定代表人
注册资本大于1亿的品牌中哪5个品牌收入最少并给出它们的法定代表人
找到注册资本少于一个亿的品牌及其法人,并给出对应公司的平均营收占比
给出注册资本不少于一亿的品牌及其法人,并给出对应的公司的平均营收占比
给出注册资本不超过一亿的的品牌及其法人,并给出对应的公司的总营收占比
给出注册资本少于一亿的品牌及其法人,并给出对应的公司的最大营收占比
给出注册资本不超过1亿的品牌及其法人并找出对应的公司的最大营收占比
在注册资本超过一亿的公司中给出公司各品牌收入品牌的平均营收占比正好50%的品牌及其法人
在注册资本不超过一亿的公司中,给出个品牌收入排名五的品牌及其法人
在各品牌收入在公司排名中当品牌的注册资本不大于1亿时给出公司各品牌收入排名的支出占比的平均值小于等于45%的那些品牌的名称以及公司各品牌收入排名的营收占比的平均值
在各品牌收入在公司排名中当品牌的注册资本小于1亿时给出公司各品牌收入排名数量小于等于5的那些品牌的名称以及公司各品牌收入排名的营收占比的最小值
在各品牌收入在公司排名中,给出每个品牌的名称,品牌的法定代表人,以及公司各品牌收入排名的营收占比的平均值
在各品牌收入在公司排名中,给出每个品牌的名称,品牌的法定代表人,以及公司各品牌收入排名的营收占比的最小值
在各品牌收入在公司排名中,给出每个品牌的名称,品牌的法定代表人,以及公司各品牌收入排名的营收占比的总和
在各品牌收入在公司排名中,给出每个品牌的名称,品牌的法定代表人,以及公司各品牌收入排名的营收占比的最大值
在各品牌收入在公司排名中给出收入排名不超过5的品牌及其法人
在各品牌收入在公司排名中给出在收入排名中的平均营收占比超过50%的品牌及其法人
在各品牌收入在公司排名中当公司各品牌收入排名的利润占比的平均值小于等于60%时,给出品牌的名称以及公司各品牌收入排名的营收占比的平均值
在各品牌收入在公司排名中当公司各品牌收入排名数量等于5时给出品牌的名称以及公司各品牌收入排名的营收占比的最小值
哪个品牌收入的平均利润占比最大,给出品牌的法定代表人,以及其收入平均营收占比
哪3个品牌的收入最多给出品牌的法定代表人以及其收入总营收占比
在公司各品牌收入排名的利润占比由少到多排列,给出对应的品牌的名称公司的名称以及公司各品牌收入排名的营收占比
在公司各品牌收入排名的支出占比最少时给出排名前3对应的品牌的名称公司的名称以及公司各品牌收入排名的营收占比
在公司各品牌收入排名的支出占比最多时给出排名前3对应的品牌的名称公司的名称以及公司各品牌收入排名的营收占比公司各品牌收入排名的利润占比
在公司各品牌收入排名的利润占比最多时给出排名前3对应的品牌的名称公司的名称以及公司各品牌收入排名的营收占比
哪些法人的品牌的不在注册资本少于100万中这些法定代表人的品牌都是哪些
给出每一个品牌和其法人,所属的公司以及该公司的年营业额
给出每一个品牌和注册时所用资本,以及所属的公司
给出每一个品牌和注册时所用资本,所属的公司以及该公司的年营业额
给出每一个品牌和注册时所用资本,所属的公司和总部所在地
给出每一个品牌和其法人,所属的公司以及总部所在城市
有自己品牌的公司有哪些?给出这些公司和总部所在地
有自己品牌的公司有哪些?给出这些公司和年营业额
在各公司其品牌的历年收入中,给出每一个品牌,其所属的公司和公司总部所在地点,并给出该品牌近几年的营收
在各公司其品牌的历年收入中,给出每一个品牌,其所属的公司和公司年营业额,并给出该品牌近几年的营收
在各公司其品牌的历年收入中,给出每一个品牌,其所属的公司和公司法人,并给出该品牌近几年的营收
在各公司其品牌的历年收入中,给出每一个品牌,其所属的公司,以及该品牌近几年的营收
在品牌的注册资本至少1亿时给出公司的名称以及公司的总部地点品牌的法定代表人
在品牌的注册资本少于1亿时给出公司的名称以及公司的总部地点品牌的法定代表人
在品牌的注册资本超过1亿时给出公司的名称以及公司的总部地点品牌的法定代表人
在品牌的注册资本最多1亿时给出公司的名称以及公司的年营业额品牌的法定代表人
找到注册资本不止1亿且年营业额不超过288亿的公司以及给出总部地点法人
找出注册资本不止一亿且年营业额低于288亿的公司以及总部地点和法人
找出注册资本不到1亿且年营业额少于288亿的公司以及给出总部地点和法人
给出注册资本不超过1亿且年营业额不少于288亿的公司总部地点和法人
给出注册资本不低于一亿且年营业额不止288亿的公司以及总部地点和法人
给出注册资本不超过一亿且年营业额超过288亿的公司总部在哪法人是谁
给出注册资本超过一亿且年营业额不止288亿的公司以及总部在哪法人是谁
给出注册资本不超过一亿且年营业额少于288亿的公司以及总部地点和法人
找出注册资本少于一亿且年营业额不超过288亿的公司以及总部在哪法人是谁
在公司品牌历年收入的利润最多500亿时给出公司的名称品牌的名称并给出公司品牌历年收入的营收
在公司品牌历年收入的营收同比增长至少100%时,给出公司的名称品牌的名称并给出公司品牌历年收入的营收
年营业额不小于288亿的公司中哪5个公司的平均营收占比最少,并给出它们的总部地点
年营业额不小于288亿的公司中哪个公司的平均营收占比最大并给出它的总部地点
注册资本大于1亿的品牌中哪个品牌历年收入的平均营收最大并给出它的法定代表人
注册资本不大于1亿的品牌中哪个品牌历年收入的平均营收最大并给出它的法定代表人
年营业额不小于288亿的公司中哪5个公司品牌的平均注册资本最多并给出它们的总部地点
年营业额不小于288亿的公司中哪5个公司品牌的平均注册资本总共最少并给出它们的总部地点
找到年营业额不少于288亿的公司及总部地点并给出对应的公司的最高营收占比
给出年营业额少于288亿的公司及总部地点并给出对应的公司的总营收占比
给出年营业额不止288亿的公司及总部地点并给出对应的公司的最大营收占比
找出年营业额不少于288亿的公司及总部地点并给出对应的公司的总营收占比
请找出年营业额不超过288亿的公司及总部地点并给出对应的公司的平均营收占比
找到品牌的注册资本不大于1亿品牌的法定代表人并给出公司品牌历年收入的营收的平均值
找到品牌的注册资本不小于1亿品牌的法定代表人并给出公司品牌历年收入的营收的平均值
找到品牌的注册资本小于1亿品牌的法定代表人并给出公司品牌历年收入的营收的最小值
找到品牌的注册资本不大于1亿品牌的法定代表人并给出公司品牌历年收入的营收的最大值
找到品牌的注册资本小于1亿品牌的法定代表人并给出公司品牌历年收入的营收的平均值
给出年营业额超过288亿的公司及总部地点并给出对应的品牌中的最小注册资本
给出年营业额不超过288亿的公司及总部地点并给出对应的品牌中的最小注册资本
给出不超过288亿年营业额的公司及总部地点并给出这些品牌中的平均注册资本
给出不超过288亿年营业额的公司及其总部地点并给出这些品牌的的最大注册资本
给出年营业额超过288亿的公司及其总部地点并给出这些品牌的的总注册资本
给出年营业额不超过288亿的各公司品牌中给出收入排名中的营收占比加起来不超过50%的公司及其总部地点
给出年营业额低于288亿的各公司的各品牌中给出收入排名中的总营收占比不超过50%的公司及总部地点中
在年营业额不少于288亿的公司中给出品牌不超过5个的公司及其总部地点
在年营业额不超过288亿的公司中给出品牌少于5个的公司及其总部地点
在各公司的各品牌收入排名中当公司的年营业额大于288亿时给出公司各品牌收入排名数量大于5的那些公司的名称以及公司各品牌收入排名的营收占比的最小值
在各公司的各品牌收入排名中当公司的年营业额小于288亿时给出公司各品牌收入排名的利润占比的平均值大于75%的那些公司的名称以及公司各品牌收入排名的营收占比的最大值
在各品牌的历年收入中当品牌的注册资本不大于1亿时给出公司品牌历年收入的利润同比增长的总和小于等于1000000的那些品牌的名称以及公司品牌历年收入的营收的最大值
在各品牌的历年收入中当品牌的注册资本小于1亿时给出公司品牌历年收入数量小于5的那些品牌的名称以及公司品牌历年收入的营收的最小值
在各品牌所属的公司中当公司的年营业额大于288亿时给出品牌数量小于5的那些公司的名称以及品牌的注册资本的总和
在各品牌所属的公司中当公司的年营业额小于288亿时给出品牌数量大于等于5的那些公司的名称以及品牌的注册资本的总和
在各公司每个品牌的收入排名中,给出每个公司,其总部地点,以及各品牌的最大营收占比
在各公司每个品牌的收入排名中,给出每个公司,其总部地点,以及各品牌的最小营收占比
在各公司每个品牌的收入排名中,给出每个公司,其总部地点,以及各品牌的总营收占比
在各公司每个品牌的收入排名中,给出每个公司,其总部地点,以及各品牌的平均营收占比
在各品牌的历年收入中,给出每个品牌的名称,品牌的法定代表人,以及公司品牌历年收入的营收的平均值
在各品牌的历年收入中,给出每个品牌的名称,品牌的法定代表人,以及公司品牌历年收入的营收的最小值
在各品牌的历年收入中,给出每个品牌的名称,品牌的法定代表人,以及公司品牌历年收入的营收的总和
在各品牌的历年收入中,给出每个品牌的名称,品牌的法定代表人,以及公司品牌历年收入的营收的最大值
在各品牌所属的公司中,给出每个公司的名称,公司的总部地点,以及品牌的注册资本的总和
在各品牌所属的公司中,给出每个公司的名称,公司的总部地点,以及品牌的注册资本的最大值

View File

@@ -0,0 +1,758 @@
[
{
"db_id": "internet",
"table_names": [
"company",
"brand",
"company_brand_revenue",
"company_revenue"
],
"column_names": [
[
-1,
"*"
],
[
0,
"company_id"
],
[
0,
"company_name"
],
[
0,
"headquarter_address"
],
[
0,
"company_established_time"
],
[
0,
"founder"
],
[
0,
"ceo"
],
[
0,
"annual_turnover"
],
[
0,
"employee_count"
],
[
1,
"brand_id"
],
[
1,
"brand_name"
],
[
1,
"brand_established_time"
],
[
1,
"company_id"
],
[
1,
"legal_representative"
],
[
1,
"registered_capital"
],
[
2,
"year_time"
],
[
2,
"brand_id"
],
[
2,
"revenue"
],
[
2,
"profit"
],
[
2,
"revenue_growth_year_on_year"
],
[
2,
"profit_growth_year_on_year"
],
[
3,
"company_id"
],
[
3,
"brand_id"
],
[
3,
"revenue_proportion"
],
[
3,
"profit_proportion"
],
[
3,
"expenditure_proportion"
]
],
"table_names_original": [
"company",
"brand",
"company_brand_revenue",
"company_revenue"
],
"column_names_original": [
[
-1,
"*"
],
[
0,
"company_id"
],
[
0,
"company_name"
],
[
0,
"headquarter_address"
],
[
0,
"company_established_time"
],
[
0,
"founder"
],
[
0,
"ceo"
],
[
0,
"annual_turnover"
],
[
0,
"employee_count"
],
[
1,
"brand_id"
],
[
1,
"brand_name"
],
[
1,
"brand_established_time"
],
[
1,
"company_id"
],
[
1,
"legal_representative"
],
[
1,
"registered_capital"
],
[
2,
"year_time"
],
[
2,
"brand_id"
],
[
2,
"revenue"
],
[
2,
"profit"
],
[
2,
"revenue_growth_year_on_year"
],
[
2,
"profit_growth_year_on_year"
],
[
3,
"company_id"
],
[
3,
"brand_id"
],
[
3,
"revenue_proportion"
],
[
3,
"profit_proportion"
],
[
3,
"expenditure_proportion"
]
],
"column_types": [
"text",
"number",
"text",
"text",
"time",
"text",
"time",
"number",
"number",
"number",
"text",
"time",
"text",
"text",
"number",
"time",
"number",
"number",
"number",
"number",
"number",
"number",
"number",
"number",
"number",
"number"
],
"foreign_keys": [
[
12,
1
],
[
21,
1
],
[
22,
9
],
[
16,
9
]
],
"primary_keys": [
1,
9
]
},
{
"db_id": "china_travel_agency",
"table_names": [
"travel_agency",
"outbound_travel_routes",
"country_outbound_travel",
"domestic_travel_routes",
"cruise_route"
],
"column_names": [
[
-1,
"*"
],
[
0,
"travel_agency_id"
],
[
0,
"travel_agency_name"
],
[
0,
"travel_agency_level"
],
[
0,
"number_countrie_outbound_travel"
],
[
0,
"number_domestic_tourist_cities"
],
[
0,
"number_outbound_travel_routes"
],
[
0,
"number_domestic_travel_routes"
],
[
0,
"asia_ranking"
],
[
0,
"number_overseas_tourists_received"
],
[
0,
"number_overseas_companies"
],
[
0,
"number_holding_subsidiaries"
],
[
0,
"number_traveling_salesmen_business_relationships"
],
[
0,
"number_duty_free_shops"
],
[
1,
"outbound_route_id"
],
[
1,
"outbound_route_name"
],
[
1,
"travel_agency_id"
],
[
1,
"outbound_departure_city"
],
[
1,
"outbound_days"
],
[
1,
"adult_price"
],
[
1,
"child_price"
],
[
1,
"countries"
],
[
1,
"attractions"
],
[
1,
"total_ticket_price"
],
[
2,
"outbound_travel_route_id"
],
[
2,
"nation"
],
[
2,
"travel_days"
],
[
2,
"outbound_number_attractions"
],
[
3,
"domestic_travel_route_id"
],
[
3,
"domestic_travel_route_name"
],
[
3,
"travel_agency_id"
],
[
3,
"domestic_departure_city"
],
[
3,
"domestic_days"
],
[
3,
"presale_price"
],
[
3,
"tour_price"
],
[
3,
"number_people_group"
],
[
3,
"personal_price"
],
[
3,
"domestic_number_attractions"
],
[
4,
"cruise_route_id"
],
[
4,
"cruise_route_name"
],
[
4,
"travel_agency_id"
],
[
4,
"cruise_departure_city"
],
[
4,
"cruise_days"
],
[
4,
"interior_cabin_price"
],
[
4,
"sea_view_room_price"
],
[
4,
"balcony_room_price"
],
[
4,
"sailing_area"
],
[
4,
"cruise_line"
]
],
"table_names_original": [
"travel_agency",
"outbound_travel_routes",
"country_outbound_travel",
"domestic_travel_routes",
"cruise_route"
],
"column_names_original": [
[
-1,
"*"
],
[
0,
"travel_agency_id"
],
[
0,
"travel_agency_name"
],
[
0,
"travel_agency_level"
],
[
0,
"number_countrie_outbound_travel"
],
[
0,
"number_domestic_tourist_cities"
],
[
0,
"number_outbound_travel_routes"
],
[
0,
"number_domestic_travel_routes"
],
[
0,
"asia_ranking"
],
[
0,
"number_overseas_tourists_received"
],
[
0,
"number_overseas_companies"
],
[
0,
"number_holding_subsidiaries"
],
[
0,
"number_traveling_salesmen_business_relationships"
],
[
0,
"number_duty_free_shops"
],
[
1,
"outbound_route_id"
],
[
1,
"outbound_route_name"
],
[
1,
"travel_agency_id"
],
[
1,
"outbound_departure_city"
],
[
1,
"outbound_days"
],
[
1,
"adult_price"
],
[
1,
"child_price"
],
[
1,
"countries"
],
[
1,
"attractions"
],
[
1,
"total_ticket_price"
],
[
2,
"outbound_travel_route_id"
],
[
2,
"nation"
],
[
2,
"travel_days"
],
[
2,
"outbound_number_attractions"
],
[
3,
"domestic_travel_route_id"
],
[
3,
"domestic_travel_route_name"
],
[
3,
"travel_agency_id"
],
[
3,
"domestic_departure_city"
],
[
3,
"domestic_days"
],
[
3,
"presale_price"
],
[
3,
"tour_price"
],
[
3,
"number_people_group"
],
[
3,
"personal_price"
],
[
3,
"domestic_number_attractions"
],
[
4,
"cruise_route_id"
],
[
4,
"cruise_route_name"
],
[
4,
"travel_agency_id"
],
[
4,
"cruise_departure_city"
],
[
4,
"cruise_days"
],
[
4,
"interior_cabin_price"
],
[
4,
"sea_view_room_price"
],
[
4,
"balcony_room_price"
],
[
4,
"sailing_area"
],
[
4,
"cruise_line"
]
],
"column_types": [
"text",
"number",
"text",
"text",
"number",
"number",
"number",
"number",
"number",
"number",
"number",
"number",
"number",
"number",
"number",
"text",
"number",
"text",
"number",
"number",
"number",
"number",
"number",
"number",
"number",
"text",
"number",
"number",
"number",
"text",
"number",
"text",
"number",
"number",
"number",
"number",
"number",
"number",
"number",
"text",
"number",
"text",
"number",
"number",
"number",
"number",
"text",
"text"
],
"foreign_keys": [
[
40,
1
],
[
24,
14
],
[
30,
1
],
[
16,
1
]
],
"primary_keys": [
1,
14,
28,
38
]
}
]

922
evaluation/evaluation.py Normal file
View File

@@ -0,0 +1,922 @@
################################
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
# 'where': condition
# 'groupBy': [col_unit1, col_unit2, ...]
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
# 'having': condition
# 'limit': None/limit value
# 'intersect': None/sql
# 'except': None/sql
# 'union': None/sql
# }
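#
# Example: "SELECT count(*) FROM company" parses (roughly) to
#   'select': (False, [(3, (0, (0, 0, False), None))])   # agg_id 3 == 'count'
#   'from': {'table_units': [('table_unit', 0)], 'conds': []}
#   'where': [], 'groupBy': [], 'having': [], 'limit': None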
################################
from __future__ import print_function
import sqlparse
import logging
import os, sys
import json
import sqlite3
import traceback
import argparse
import yaml
import re
from process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql
from build_pred_result import read_query,get_pred_result
from build_tables import build_table
# Flag to disable value evaluation
DISABLE_VALUE = True
# Flag to disable distinct in select evaluation
DISABLE_DISTINCT = True
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
'sql': "sql",
'table_unit': "table_unit",
}
COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')
HARDNESS = {
"component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
"component2": ('except', 'union', 'intersect')
}
def condition_has_or(conds):
return 'or' in conds[1::2]
def condition_has_like(conds):
return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]
def condition_has_sql(conds):
for cond_unit in conds[::2]:
val1, val2 = cond_unit[3], cond_unit[4]
if val1 is not None and type(val1) is dict:
return True
if val2 is not None and type(val2) is dict:
return True
return False
def val_has_op(val_unit):
return val_unit[0] != UNIT_OPS.index('none')
def has_agg(unit):
return unit[0] != AGG_OPS.index('none')
def accuracy(count, total):
if count == total:
return 1
return 0
def recall(count, total):
if count == total:
return 1
return 0
def F1(acc, rec):
if (acc + rec) == 0:
return 0
return (2. * acc * rec) / (acc + rec)
def get_scores(count, pred_total, label_total):
if pred_total != label_total:
return 0,0,0
elif count == pred_total:
return 1,1,1
return 0,0,0
def eval_sel(pred, label):
pred_sel = pred['select'][1]
label_sel = label['select'][1]
label_wo_agg = [unit[1] for unit in label_sel]
pred_total = len(pred_sel)
label_total = len(label_sel)
cnt = 0
cnt_wo_agg = 0
for unit in pred_sel:
if unit in label_sel:
cnt += 1
label_sel.remove(unit)
if unit[1] in label_wo_agg:
cnt_wo_agg += 1
label_wo_agg.remove(unit[1])
return label_total, pred_total, cnt, cnt_wo_agg
def eval_where(pred, label):
pred_conds = [unit for unit in pred['where'][::2]]
label_conds = [unit for unit in label['where'][::2]]
label_wo_agg = [unit[2] for unit in label_conds]
pred_total = len(pred_conds)
label_total = len(label_conds)
cnt = 0
cnt_wo_agg = 0
for unit in pred_conds:
if unit in label_conds:
cnt += 1
label_conds.remove(unit)
if unit[2] in label_wo_agg:
cnt_wo_agg += 1
label_wo_agg.remove(unit[2])
return label_total, pred_total, cnt, cnt_wo_agg
def eval_group(pred, label):
pred_cols = [unit[1] for unit in pred['groupBy']]
label_cols = [unit[1] for unit in label['groupBy']]
pred_total = len(pred_cols)
label_total = len(label_cols)
cnt = 0
pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols]
label_cols = [label.split(".")[1] if "." in label else label for label in label_cols]
for col in pred_cols:
if col in label_cols:
cnt += 1
label_cols.remove(col)
return label_total, pred_total, cnt
def eval_having(pred, label):
pred_total = label_total = cnt = 0
if len(pred['groupBy']) > 0:
pred_total = 1
if len(label['groupBy']) > 0:
label_total = 1
pred_cols = [unit[1] for unit in pred['groupBy']]
label_cols = [unit[1] for unit in label['groupBy']]
if pred_total == label_total == 1 \
and pred_cols == label_cols \
and pred['having'] == label['having']:
cnt = 1
return label_total, pred_total, cnt
def eval_order(pred, label):
pred_total = label_total = cnt = 0
if len(pred['orderBy']) > 0:
pred_total = 1
if len(label['orderBy']) > 0:
label_total = 1
if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):
cnt = 1
return label_total, pred_total, cnt
def eval_and_or(pred, label):
pred_ao = pred['where'][1::2]
label_ao = label['where'][1::2]
pred_ao = set(pred_ao)
label_ao = set(label_ao)
if pred_ao == label_ao:
return 1,1,1
return len(pred_ao),len(label_ao),0
def get_nestedSQL(sql):
nested = []
for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
if type(cond_unit[3]) is dict:
nested.append(cond_unit[3])
if type(cond_unit[4]) is dict:
nested.append(cond_unit[4])
if sql['intersect'] is not None:
nested.append(sql['intersect'])
if sql['except'] is not None:
nested.append(sql['except'])
if sql['union'] is not None:
nested.append(sql['union'])
return nested
def eval_nested(pred, label):
label_total = 0
pred_total = 0
cnt = 0
if pred is not None:
pred_total += 1
if label is not None:
label_total += 1
if pred is not None and label is not None:
cnt += Evaluator().eval_exact_match(pred, label)
return label_total, pred_total, cnt
def eval_IUEN(pred, label):
lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
label_total = lt1 + lt2 + lt3
pred_total = pt1 + pt2 + pt3
cnt = cnt1 + cnt2 + cnt3
return label_total, pred_total, cnt
def get_keywords(sql):
res = set()
if len(sql['where']) > 0:
res.add('where')
if len(sql['groupBy']) > 0:
res.add('group')
if len(sql['having']) > 0:
res.add('having')
if len(sql['orderBy']) > 0:
res.add(sql['orderBy'][0])
res.add('order')
if sql['limit'] is not None:
res.add('limit')
if sql['except'] is not None:
res.add('except')
if sql['union'] is not None:
res.add('union')
if sql['intersect'] is not None:
res.add('intersect')
# or keyword
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
if len([token for token in ao if token == 'or']) > 0:
res.add('or')
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
# not keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
res.add('not')
# in keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
res.add('in')
# like keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
res.add('like')
return res
def eval_keywords(pred, label):
pred_keywords = get_keywords(pred)
label_keywords = get_keywords(label)
pred_total = len(pred_keywords)
label_total = len(label_keywords)
cnt = 0
for k in pred_keywords:
if k in label_keywords:
cnt += 1
return label_total, pred_total, cnt
def count_agg(units):
return len([unit for unit in units if has_agg(unit)])
def count_component1(sql):
count = 0
if len(sql['where']) > 0:
count += 1
if len(sql['groupBy']) > 0:
count += 1
if len(sql['orderBy']) > 0:
count += 1
if sql['limit'] is not None:
count += 1
if len(sql['from']['table_units']) > 0: # JOIN
count += len(sql['from']['table_units']) - 1
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
count += len([token for token in ao if token == 'or'])
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])
return count
def count_component2(sql):
nested = get_nestedSQL(sql)
return len(nested)
def count_others(sql):
count = 0
# number of aggregation
agg_count = count_agg(sql['select'][1])
agg_count += count_agg(sql['where'][::2])
agg_count += count_agg(sql['groupBy'])
if len(sql['orderBy']) > 0:
agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
[unit[2] for unit in sql['orderBy'][1] if unit[2]])
agg_count += count_agg(sql['having'])
if agg_count > 1:
count += 1
# number of select columns
if len(sql['select'][1]) > 1:
count += 1
# number of where conditions
if len(sql['where']) > 1:
count += 1
# number of group by clauses
if len(sql['groupBy']) > 1:
count += 1
return count
class Evaluator:
"""A simple evaluator"""
def __init__(self):
self.partial_scores = None
def eval_hardness(self, sql):
count_comp1_ = count_component1(sql)
count_comp2_ = count_component2(sql)
count_others_ = count_others(sql)
if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
return "easy"
elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
(count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
return "medium"
elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
(2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
(count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
return "hard"
else:
return "extra"
def eval_exact_match(self, pred, label):
partial_scores = self.eval_partial_match(pred, label)
self.partial_scores = partial_scores
for _, score in partial_scores.items():
if score['f1'] != 1:
return 0
if len(label['from']['table_units']) > 0:
label_tables = sorted(label['from']['table_units'])
pred_tables = sorted(pred['from']['table_units'])
return label_tables == pred_tables
return 1
def eval_partial_match(self, pred, label):
res = {}
label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_group(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_having(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_order(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_and_or(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_IUEN(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_keywords(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
return res
def isValidSQL(sql, db):
conn = sqlite3.connect(db)
cursor = conn.cursor()
try:
cursor.execute(sql)
except Exception:
return False
finally:
conn.close()
return True
def print_scores(scores, etype):
levels = ['easy', 'medium', 'hard', 'extra', 'all']
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
'group', 'order', 'and/or', 'IUEN', 'keywords']
print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
counts = [scores[level]['count'] for level in levels]
print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))
if etype in ["all", "exec"]:
print('===================== EXECUTION ACCURACY =====================')
this_scores = [scores[level]['exec'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))
if etype in ["all", "match"]:
print('\n====================== EXACT MATCHING ACCURACY =====================')
exact_scores = [scores[level]['exact'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
print('---------------------- PARTIAL MATCHING RECALL ----------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
print('---------------------- PARTIAL MATCHING F1 --------------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
def evaluate(gold, predict, db_dir, etype, kmaps, query_path):
with open(gold) as f:
glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
with open(predict) as f:
plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
evaluator = Evaluator()
levels = ['easy', 'medium', 'hard', 'extra', 'all']
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
'group', 'order', 'and/or', 'IUEN', 'keywords']
entries = []
scores = {}
log_list = []
for level in levels:
scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
scores[level]['exec'] = 0
for type_ in partial_types:
scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}
questions = read_query(query_path)
index = 0
for p, g in zip(plist, glist):
p_str = p[0]
g_str, db = g
db = os.path.join(db_dir, db + ".db")
print(db)
schema = Schema(get_schema(db))
g_sql = get_sql(schema, g_str)
hardness = evaluator.eval_hardness(g_sql)
scores[hardness]['count'] += 1
scores['all']['count'] += 1
p_sql = p_str
if etype in ["all", "exec"]:
exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
if not exec_score:
element = {}
element["query"] = questions[index]
element["gold_sql"] = g_str
element["pred_sql"] = p_str
log_list.append(element)
if exec_score:
scores[hardness]['exec'] += 1.0
scores['all']['exec'] += 1.0
if etype in ["all", "match"]:
exact_score = evaluator.eval_exact_match(p_sql, g_sql)
partial_scores = evaluator.partial_scores
if exact_score == 0:
print("{} pred: {}".format(hardness,p_str))
print("{} gold: {}".format(hardness,g_str))
print("")
scores[hardness]['exact'] += exact_score
scores['all']['exact'] += exact_score
for type_ in partial_types:
if partial_scores[type_]['pred_total'] > 0:
scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
scores[hardness]['partial'][type_]['acc_count'] += 1
if partial_scores[type_]['label_total'] > 0:
scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
scores[hardness]['partial'][type_]['rec_count'] += 1
scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
if partial_scores[type_]['pred_total'] > 0:
scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
scores['all']['partial'][type_]['acc_count'] += 1
if partial_scores[type_]['label_total'] > 0:
scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
scores['all']['partial'][type_]['rec_count'] += 1
scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']
entries.append({
'predictSQL': p_str,
'goldSQL': g_str,
'hardness': hardness,
'exact': exact_score,
'partial': partial_scores
})
index += 1
for level in levels:
if scores[level]['count'] == 0:
continue
if etype in ["all", "exec"]:
scores[level]['exec'] /= scores[level]['count']
if etype in ["all", "match"]:
scores[level]['exact'] /= scores[level]['count']
for type_ in partial_types:
if scores[level]['partial'][type_]['acc_count'] == 0:
scores[level]['partial'][type_]['acc'] = 0
else:
scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
scores[level]['partial'][type_]['acc_count'] * 1.0
if scores[level]['partial'][type_]['rec_count'] == 0:
scores[level]['partial'][type_]['rec'] = 0
else:
scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
scores[level]['partial'][type_]['rec_count'] * 1.0
if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
scores[level]['partial'][type_]['f1'] = 1
else:
scores[level]['partial'][type_]['f1'] = \
2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
print_scores(scores, etype)
print("overall execution accuracy: {}".format(scores['all']['exec']))
current_directory = os.path.dirname(os.path.abspath(__file__))
file_name = current_directory + "/eval.json"
# 'w' mode truncates an existing file, so no need to delete it first
with open(file_name, 'w') as json_file:
json.dump(log_list, json_file, indent=4, ensure_ascii=False)
def eval_exec_match(db, p_str, g_str, pred, gold):
"""
Return 1 if the prediction's and gold's result values match for the
corresponding columns. Multiple col_unit pairs are currently not supported.
"""
conn = sqlite3.connect(db)
cursor = conn.cursor()
try:
cursor.execute(p_str)
columns_tuple = cursor.description
p_fields = [field_tuple[0] for field_tuple in columns_tuple]
# strip backticks and "tN." alias prefixes from field names, then lowercase
for index in range(0, len(p_fields)):
p_fields[index] = re.sub(r"t\d+\.", "", p_fields[index].replace("`", "").lower())
p_res = cursor.fetchall()
except Exception:
return False
cursor.execute(g_str)
q_res = cursor.fetchall()
def res_map(res, p_fields):
rmap = {}
for i in range(0, len(p_fields)):
if p_fields[i] != "sys_imp_date":
value_list = [r[i] for r in res]
value_list.sort()
rmap[p_fields[i]] = value_list
return rmap
g_fields = parse_sql(g_str)
#print("p_res_map:{}".format(res_map(p_res, p_fields)))
#print("q_res_map:{}".format(res_map(q_res, g_fields)))
return res_map(p_res, p_fields) == res_map(q_res, g_fields)
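# For illustration: res_map([(3, 'x'), (1, 'y')], ['cnt', 'name']) returns
# {'cnt': [1, 3], 'name': ['x', 'y']} -- each selected column is compared as
# an order-insensitive list of values, keyed by its normalized field name.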
def parse_sql(sql):
# parse the SQL query with the sqlparse library
parsed = sqlparse.parse(sql)[0]
# get the query type (SELECT, INSERT, UPDATE or DELETE)
query_type = parsed.get_type()
# get the query target (table name, field list, value list, etc.)
if query_type == 'SELECT':
target = parse_select(parsed)
else:
target = None
return target
def parse_select(parsed):
# collect the selected field names, stripping backticks and alias prefixes
fields = []
for token in parsed.tokens:
if isinstance(token, sqlparse.sql.IdentifierList):
for identifier in token.get_identifiers():
fields.append(identifier.value.replace("`", "")
.replace("T1.", "").replace("T2.", "")
.replace("T3.", "").replace("T4.", "")
.replace("T5.", "").replace("T6.", ""))
if len(fields):
break
return fields
# Rebuild SQL functions for value evaluation
def rebuild_cond_unit_val(cond_unit):
if cond_unit is None or not DISABLE_VALUE:
return cond_unit
not_op, op_id, val_unit, val1, val2 = cond_unit
if type(val1) is not dict:
val1 = None
else:
val1 = rebuild_sql_val(val1)
if type(val2) is not dict:
val2 = None
else:
val2 = rebuild_sql_val(val2)
return not_op, op_id, val_unit, val1, val2
def rebuild_condition_val(condition):
if condition is None or not DISABLE_VALUE:
return condition
res = []
for idx, it in enumerate(condition):
if idx % 2 == 0:
res.append(rebuild_cond_unit_val(it))
else:
res.append(it)
return res
def rebuild_sql_val(sql):
if sql is None or not DISABLE_VALUE:
return sql
sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
sql['having'] = rebuild_condition_val(sql['having'])
sql['where'] = rebuild_condition_val(sql['where'])
sql['intersect'] = rebuild_sql_val(sql['intersect'])
sql['except'] = rebuild_sql_val(sql['except'])
sql['union'] = rebuild_sql_val(sql['union'])
return sql
# Rebuild SQL functions for foreign key evaluation
def build_valid_col_units(table_units, schema):
col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
prefixes = [col_id[:-2] for col_id in col_ids]
valid_col_units = []
for value in schema.idMap.values():
if '.' in value and value[:value.index('.')] in prefixes:
valid_col_units.append(value)
return valid_col_units
def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
if col_unit is None:
return col_unit
agg_id, col_id, distinct = col_unit
if col_id in kmap and col_id in valid_col_units:
col_id = kmap[col_id]
if DISABLE_DISTINCT:
distinct = None
return agg_id, col_id, distinct
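# For illustration, with kmap == {'__t2.tid__': '__t1.id__'} and '__t2.tid__'
# present in valid_col_units, the col_unit (0, '__t2.tid__', False) is
# rewritten to (0, '__t1.id__', None): foreign-key columns compare equal to
# the primary keys they reference, and distinct is dropped because
# DISABLE_DISTINCT is set.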
def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
if val_unit is None:
return val_unit
unit_op, col_unit1, col_unit2 = val_unit
col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
return unit_op, col_unit1, col_unit2
def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
if table_unit is None:
return table_unit
table_type, col_unit_or_sql = table_unit
if isinstance(col_unit_or_sql, tuple):
col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
return table_type, col_unit_or_sql
def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
if cond_unit is None:
return cond_unit
not_op, op_id, val_unit, val1, val2 = cond_unit
val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
return not_op, op_id, val_unit, val1, val2
def rebuild_condition_col(valid_col_units, condition, kmap):
for idx in range(len(condition)):
if idx % 2 == 0:
condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
return condition
def rebuild_select_col(valid_col_units, sel, kmap):
if sel is None:
return sel
distinct, _list = sel
new_list = []
for it in _list:
agg_id, val_unit = it
new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
if DISABLE_DISTINCT:
distinct = None
return distinct, new_list
def rebuild_from_col(valid_col_units, from_, kmap):
if from_ is None:
return from_
from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]
from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
return from_
def rebuild_group_by_col(valid_col_units, group_by, kmap):
if group_by is None:
return group_by
return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]
def rebuild_order_by_col(valid_col_units, order_by, kmap):
if order_by is None or len(order_by) == 0:
return order_by
direction, val_units = order_by
new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
return direction, new_val_units
def rebuild_sql_col(valid_col_units, sql, kmap):
if sql is None:
return sql
sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)
return sql
def build_foreign_key_map(entry):
cols_orig = entry["column_names_original"]
tables_orig = entry["table_names_original"]
# rebuild cols corresponding to idmap in Schema
cols = []
for col_orig in cols_orig:
if col_orig[0] >= 0:
t = tables_orig[col_orig[0]]
c = col_orig[1]
cols.append("__" + t.lower() + "." + c.lower() + "__")
else:
cols.append("__all__")
def keyset_in_list(k1, k2, k_list):
for k_set in k_list:
if k1 in k_set or k2 in k_set:
return k_set
new_k_set = set()
k_list.append(new_k_set)
return new_k_set
foreign_key_list = []
foreign_keys = entry["foreign_keys"]
for fkey in foreign_keys:
key1, key2 = fkey
key_set = keyset_in_list(key1, key2, foreign_key_list)
key_set.add(key1)
key_set.add(key2)
foreign_key_map = {}
for key_set in foreign_key_list:
sorted_list = sorted(list(key_set))
midx = sorted_list[0]
for idx in sorted_list:
foreign_key_map[cols[idx]] = cols[midx]
return foreign_key_map
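# For illustration, with cols == ['__all__', '__t1.id__', '__t2.tid__'] and
# foreign_keys == [[1, 2]], the two keys form one set {1, 2}; every member
# maps to the column at the lowest index, giving
# {'__t1.id__': '__t1.id__', '__t2.tid__': '__t1.id__'}.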
def build_foreign_key_map_from_json(table):
with open(table) as f:
data = json.load(f)
tables = {}
for entry in data:
tables[entry['db_id']] = build_foreign_key_map(entry)
return tables
def get_evaluation_result():
current_directory = os.path.dirname(os.path.abspath(__file__))
config_file = current_directory + "/config/config.yaml"
with open(config_file, 'r') as file:
config = yaml.safe_load(file)
db_dir = current_directory + "/data"
pred = current_directory + "/data/" + "pred_example_dusql.txt"
gold = current_directory + "/data/" + "gold_example_dusql.txt"
table = current_directory + "/data/" + "tables_dusql.json"
query_path = current_directory + "/data/" + config["domain"] + ".txt"
etype = "exec"
kmaps = build_foreign_key_map_from_json(table)
evaluate(gold, pred, db_dir, etype, kmaps, query_path)
def remove_unused_file():
current_directory = os.path.dirname(os.path.abspath(__file__))
config_file = current_directory + "/config/config.yaml"
with open(config_file, 'r') as file:
config = yaml.safe_load(file)
db_path = current_directory + "/data/"
db_file = db_path + config["domain"] + ".db"
pred_file = current_directory + "/data/" + "pred_example_dusql.txt"
if os.path.exists(db_file):
os.remove(db_file)
print("db_file removed!")
if os.path.exists(pred_file):
os.remove(pred_file)
print("pred_file removed!")
if __name__ == "__main__":
build_table()
get_pred_result()
get_evaluation_result()
remove_unused_file()

12
evaluation/evaluation.sh Normal file
View File

@@ -0,0 +1,12 @@
#!/usr/bin/env bash
path=$(pwd)
echo ${path}
python_path=${PYTHON_PATH:-"python3"}
pip_path=${PIP_PATH:-"pip3"}
requirementPath=$path/requirements.txt
${pip_path} install -r ${requirementPath}
echo "python modules installed successfully"
${python_path} $path/evaluation.py

566
evaluation/process_sql.py Normal file
View File

@@ -0,0 +1,566 @@
################################
# Assumptions:
# 1. sql is correct
# 2. only table name has alias
# 3. only one intersect/union/except
#
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
# 'where': condition
# 'groupBy': [col_unit1, col_unit2, ...]
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
# 'having': condition
# 'limit': None/limit value
# 'intersect': None/sql
# 'except': None/sql
# 'union': None/sql
# }
################################
import json
import sqlite3
from nltk import word_tokenize
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
'sql': "sql",
'table_unit': "table_unit",
}
COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')
class Schema:
"""
Simple schema which maps table&column to a unique identifier
"""
def __init__(self, schema):
self._schema = schema
self._idMap = self._map(self._schema)
@property
def schema(self):
return self._schema
@property
def idMap(self):
return self._idMap
def _map(self, schema):
idMap = {'*': "__all__"}
for key, vals in schema.items():
for val in vals:
idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__"
for key in schema:
idMap[key.lower()] = "__" + key.lower() + "__"
return idMap
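# For illustration, a schema like {'company': ['company_id', 'ceo']} yields
# idMap == {'*': '__all__',
#           'company.company_id': '__company.company_id__',
#           'company.ceo': '__company.ceo__',
#           'company': '__company__'}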
def get_schema(db):
"""
Get database's schema, which is a dict with table name as key
and list of column names as value
:param db: database path
:return: schema dict
"""
schema = {}
conn = sqlite3.connect(db)
cursor = conn.cursor()
# fetch table names
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
tables = [str(table[0].lower()) for table in cursor.fetchall()]
# tables.remove("internet.company")
# tables.remove("internet.brand")
# tables.remove("internet.company_revenue")
# tables.remove("internet.company_brand_revenue")
# fetch table info
for table in tables:
cursor.execute("PRAGMA table_info({})".format(table))
schema[table] = [str(col[1].lower()) for col in cursor.fetchall()]
return schema
def get_schema_from_json(fpath):
with open(fpath) as f:
data = json.load(f)
schema = {}
for entry in data:
table = str(entry['table'].lower())
cols = [str(col['column_name'].lower()) for col in entry['col_data']]
schema[table] = cols
return schema
def tokenize(string):
string = str(string)
string = string.replace("\'", "\"")  # normalize quotes so all string values are wrapped in ""
quote_idxs = [idx for idx, char in enumerate(string) if char == '"']
assert len(quote_idxs) % 2 == 0, "Unexpected quote"
# keep string value as token
vals = {}
for i in range(len(quote_idxs)-1, -1, -2):
qidx1 = quote_idxs[i-1]
qidx2 = quote_idxs[i]
val = string[qidx1: qidx2+1]
key = "__val_{}_{}__".format(qidx1, qidx2)
string = string[:qidx1] + key + string[qidx2+1:]
vals[key] = val
toks = [word.lower() for word in word_tokenize(string)]
# replace with string value token
for i in range(len(toks)):
if toks[i] in vals:
toks[i] = vals[toks[i]]
# find if there exists !=, >=, <=
eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="]
eq_idxs.reverse()
prefix = ('!', '>', '<')
for eq_idx in eq_idxs:
pre_tok = toks[eq_idx-1]
if pre_tok in prefix:
toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ]
return toks
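# For illustration: tokenize("SELECT name FROM t WHERE a != 'x'") returns
# ['select', 'name', 'from', 't', 'where', 'a', '!=', '"x"'] -- single quotes
# are normalized to double quotes, quoted literals survive as single tokens,
# and '!', '>' or '<' followed by '=' is fused into one operator token.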
def scan_alias(toks):
"""Scan the index of 'as' and build the map for all alias"""
as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as']
alias = {}
for idx in as_idxs:
alias[toks[idx+1]] = toks[idx-1]
return alias
def get_tables_with_alias(schema, toks):
tables = scan_alias(toks)
for key in schema:
assert key not in tables, "Alias {} conflicts with a table name".format(key)
tables[key] = key
return tables
def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None):
"""
:returns next idx, column id
"""
tok = toks[start_idx]
if tok == "*":
return start_idx + 1, schema.idMap[tok]
if '.' in tok: # if token is a composite
alias, col = tok.split('.')
key = tables_with_alias[alias] + "." + col
return start_idx+1, schema.idMap[key]
assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty"
for alias in default_tables:
table = tables_with_alias[alias]
if tok in schema.schema[table]:
key = table + "." + tok
return start_idx+1, schema.idMap[key]
assert False, "Error col: {}".format(tok)
def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
"""
:returns next idx, (agg_op id, col_id)
"""
idx = start_idx
len_ = len(toks)
isBlock = False
isDistinct = False
if toks[idx] == '(':
isBlock = True
idx += 1
if toks[idx] in AGG_OPS:
agg_id = AGG_OPS.index(toks[idx])
idx += 1
assert idx < len_ and toks[idx] == '('
idx += 1
if toks[idx] == "distinct":
idx += 1
isDistinct = True
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
assert idx < len_ and toks[idx] == ')'
idx += 1
return idx, (agg_id, col_id, isDistinct)
if toks[idx] == "distinct":
idx += 1
isDistinct = True
agg_id = AGG_OPS.index("none")
idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables)
if isBlock:
assert toks[idx] == ')'
idx += 1 # skip ')'
return idx, (agg_id, col_id, isDistinct)
def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
isBlock = False
if toks[idx] == '(':
isBlock = True
idx += 1
col_unit1 = None
col_unit2 = None
unit_op = UNIT_OPS.index('none')
idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
if idx < len_ and toks[idx] in UNIT_OPS:
unit_op = UNIT_OPS.index(toks[idx])
idx += 1
idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
if isBlock:
assert toks[idx] == ')'
idx += 1 # skip ')'
return idx, (unit_op, col_unit1, col_unit2)
def parse_table_unit(toks, start_idx, tables_with_alias, schema):
"""
:returns next idx, table id, table name
"""
idx = start_idx
len_ = len(toks)
key = tables_with_alias[toks[idx]]
if idx + 1 < len_ and toks[idx+1] == "as":
idx += 3
else:
idx += 1
return idx, schema.idMap[key], key
def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
isBlock = False
if toks[idx] == '(':
isBlock = True
idx += 1
if toks[idx] == 'select':
idx, val = parse_sql(toks, idx, tables_with_alias, schema)
elif "\"" in toks[idx]: # token is a string value
val = toks[idx]
idx += 1
else:
try:
val = float(toks[idx])
idx += 1
except ValueError:  # not a number, so parse it as a column reference
end_idx = idx
while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\
and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS:
end_idx += 1
idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables)
idx = end_idx
if isBlock:
assert toks[idx] == ')'
idx += 1
return idx, val
def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
conds = []
while idx < len_:
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
not_op = False
if toks[idx] == 'not':
not_op = True
idx += 1
assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx])
op_id = WHERE_OPS.index(toks[idx])
idx += 1
val1 = val2 = None
if op_id == WHERE_OPS.index('between'): # between..and... special case: dual values
idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
assert toks[idx] == 'and'
idx += 1
idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
else: # normal case: single value
idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables)
val2 = None
conds.append((not_op, op_id, val_unit, val1, val2))
if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS):
break
if idx < len_ and toks[idx] in COND_OPS:
conds.append(toks[idx])
idx += 1 # skip and/or
return idx, conds
def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None):
idx = start_idx
len_ = len(toks)
assert toks[idx] == 'select', "'select' not found"
idx += 1
isDistinct = False
if idx < len_ and toks[idx] == 'distinct':
idx += 1
isDistinct = True
val_units = []
while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS:
agg_id = AGG_OPS.index("none")
if toks[idx] in AGG_OPS:
agg_id = AGG_OPS.index(toks[idx])
idx += 1
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
val_units.append((agg_id, val_unit))
if idx < len_ and toks[idx] == ',':
idx += 1 # skip ','
return idx, (isDistinct, val_units)
def parse_from(toks, start_idx, tables_with_alias, schema):
"""
Assume in the from clause, all table units are combined with join
"""
assert 'from' in toks[start_idx:], "'from' not found"
len_ = len(toks)
idx = toks.index('from', start_idx) + 1
default_tables = []
table_units = []
conds = []
while idx < len_:
isBlock = False
if toks[idx] == '(':
isBlock = True
idx += 1
if toks[idx] == 'select':
idx, sql = parse_sql(toks, idx, tables_with_alias, schema)
table_units.append((TABLE_TYPE['sql'], sql))
else:
if idx < len_ and toks[idx] == 'join':
idx += 1 # skip join
idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema)
table_units.append((TABLE_TYPE['table_unit'], table_unit))
default_tables.append(table_name)
if idx < len_ and toks[idx] == "on":
idx += 1 # skip on
idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
if len(conds) > 0:
conds.append('and')
conds.extend(this_conds)
if isBlock:
assert toks[idx] == ')'
idx += 1
if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
break
return idx, table_units, conds, default_tables
def parse_where(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
if idx >= len_ or toks[idx] != 'where':
return idx, []
idx += 1
idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
return idx, conds
def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
col_units = []
if idx >= len_ or toks[idx] != 'group':
return idx, col_units
idx += 1
assert toks[idx] == 'by'
idx += 1
while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables)
col_units.append(col_unit)
if idx < len_ and toks[idx] == ',':
idx += 1 # skip ','
else:
break
return idx, col_units
def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
val_units = []
order_type = 'asc' # default type is 'asc'
if idx >= len_ or toks[idx] != 'order':
return idx, val_units
idx += 1
assert toks[idx] == 'by'
idx += 1
while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")):
idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables)
val_units.append(val_unit)
if idx < len_ and toks[idx] in ORDER_OPS:
order_type = toks[idx]
idx += 1
if idx < len_ and toks[idx] == ',':
idx += 1 # skip ','
else:
break
return idx, (order_type, val_units)
def parse_having(toks, start_idx, tables_with_alias, schema, default_tables):
idx = start_idx
len_ = len(toks)
if idx >= len_ or toks[idx] != 'having':
return idx, []
idx += 1
idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables)
return idx, conds
def parse_limit(toks, start_idx):
idx = start_idx
len_ = len(toks)
if idx < len_ and toks[idx] == 'limit':
idx += 2
return idx, int(toks[idx-1])
return idx, None
def parse_sql(toks, start_idx, tables_with_alias, schema):
isBlock = False # indicate whether this is a block of sql/sub-sql
len_ = len(toks)
idx = start_idx
sql = {}
if toks[idx] == '(':
isBlock = True
idx += 1
# parse from clause in order to get default tables
from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema)
sql['from'] = {'table_units': table_units, 'conds': conds}
# select clause
_, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables)
idx = from_end_idx
sql['select'] = select_col_units
# where clause
idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables)
sql['where'] = where_conds
# group by clause
idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables)
sql['groupBy'] = group_col_units
# having clause
idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables)
sql['having'] = having_conds
# order by clause
idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables)
sql['orderBy'] = order_col_units
# limit clause
idx, limit_val = parse_limit(toks, idx)
sql['limit'] = limit_val
idx = skip_semicolon(toks, idx)
if isBlock:
assert toks[idx] == ')'
idx += 1 # skip ')'
idx = skip_semicolon(toks, idx)
# intersect/union/except clause
for op in SQL_OPS: # initialize IUE
sql[op] = None
if idx < len_ and toks[idx] in SQL_OPS:
sql_op = toks[idx]
idx += 1
idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema)
sql[sql_op] = IUE_sql
return idx, sql
def load_data(fpath):
with open(fpath) as f:
data = json.load(f)
return data
def get_sql(schema, query):
toks = tokenize(query)
tables_with_alias = get_tables_with_alias(schema.schema, toks)
_, sql = parse_sql(toks, 0, tables_with_alias, schema)
return sql
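# Illustrative usage (paths are hypothetical):
#   schema = Schema(get_schema("data/internet.db"))
#   sql_dict = get_sql(schema, "SELECT ceo FROM company LIMIT 3")
#   sql_dict['limit']  # -> 3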
def skip_semicolon(toks, start_idx):
idx = start_idx
while idx < len(toks) and toks[idx] == ";":
idx += 1
return idx

View File

@@ -0,0 +1,6 @@
pysqlite3==0.5.2
PyJWT==2.8.0
PyYAML==6.0.1
sqlparse==0.4.4
nltk==3.8.1

View File

@@ -91,6 +91,7 @@ public class ChatDemoLoader implements CommandLineRunner {
addAgent1();
addAgent2();
addAgent3();
addAgent4();
addSampleChats();
addSampleChats2();
updateQueryScore(1);
@@ -508,6 +509,26 @@ public class ChatDemoLoader implements CommandLineRunner {
agentService.createAgent(agent, User.getFakeUser());
}
private void addAgent4() {
Agent agent = new Agent();
agent.setId(4);
agent.setName("DuSQL 互联网企业");
agent.setDescription("DuSQL");
agent.setStatus(1);
agent.setEnableSearch(1);
agent.setExamples(Lists.newArrayList());
AgentConfig agentConfig = new AgentConfig();
LLMParserTool llmParserTool = new LLMParserTool();
llmParserTool.setId("1");
llmParserTool.setType(AgentToolType.NL2SQL_LLM);
llmParserTool.setModelIds(Lists.newArrayList(9L, 10L, 11L, 12L));
agentConfig.getTools().add(llmParserTool);
agent.setAgentConfig(JSONObject.toJSONString(agentConfig));
agentService.createAgent(agent, User.getFakeUser());
}
private void updateQueryScore(Integer queryId) {
chatService.updateFeedback(queryId, 5, "");
}

View File

@@ -0,0 +1,292 @@
package com.tencent.supersonic;
import com.google.common.collect.Lists;
import com.tencent.supersonic.auth.api.authentication.pojo.User;
import com.tencent.supersonic.common.pojo.JoinCondition;
import com.tencent.supersonic.common.pojo.ModelRela;
import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum;
import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum;
import com.tencent.supersonic.headless.api.pojo.enums.DimensionType;
import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType;
import com.tencent.supersonic.headless.api.pojo.Dim;
import com.tencent.supersonic.headless.api.pojo.DimensionTimeTypeParams;
import com.tencent.supersonic.headless.api.pojo.Identify;
import com.tencent.supersonic.headless.api.pojo.Measure;
import com.tencent.supersonic.headless.api.pojo.ModelDetail;
import com.tencent.supersonic.headless.api.pojo.request.DomainReq;
import com.tencent.supersonic.headless.api.pojo.request.MetricReq;
import com.tencent.supersonic.headless.api.pojo.request.ModelReq;
import com.tencent.supersonic.headless.api.pojo.response.MetricResp;
import com.tencent.supersonic.headless.server.service.DomainService;
import com.tencent.supersonic.headless.server.service.MetricService;
import com.tencent.supersonic.headless.server.service.ModelRelaService;
import com.tencent.supersonic.headless.server.service.ModelService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@Component
@Slf4j
public class DuSQLDemoDataLoader {
private User user = User.getFakeUser();
@Autowired
private DomainService domainService;
@Autowired
private ModelService modelService;
@Autowired
private ModelRelaService modelRelaService;
@Autowired
private MetricService metricService;
public void doRun() {
try {
addDomain();
addModel_1();
addModel_2();
addModel_3();
addModel_4();
addModelRela_1();
addModelRela_2();
addModelRela_3();
addModelRela_4();
} catch (Exception e) {
log.error("Failed to add bench mark demo data", e);
}
}
public void addDomain() {
DomainReq domainReq = new DomainReq();
domainReq.setName("DuSQL_互联网企业");
domainReq.setBizName("internet");
domainReq.setParentId(0L);
domainReq.setViewers(Arrays.asList("admin", "tom", "jack"));
domainReq.setViewOrgs(Collections.singletonList("1"));
domainReq.setAdmins(Collections.singletonList("admin"));
domainReq.setAdminOrgs(Collections.emptyList());
domainService.createDomain(domainReq, user);
}
// model id 9
public void addModel_1() throws Exception {
ModelReq modelReq = new ModelReq();
modelReq.setName("公司");
modelReq.setBizName("company");
modelReq.setDescription("公司");
modelReq.setDatabaseId(1L);
modelReq.setDomainId(4L);
modelReq.setViewers(Arrays.asList("admin", "tom", "jack"));
modelReq.setViewOrgs(Collections.singletonList("1"));
modelReq.setAdmins(Collections.singletonList("admin"));
modelReq.setAdminOrgs(Collections.emptyList());
ModelDetail modelDetail = new ModelDetail();
List<Dim> dimensions = new ArrayList<>();
Dim dimension1 = new Dim("", "imp_date", DimensionType.time.name(), 0);
DimensionTimeTypeParams dimensionTimeTypeParams = new DimensionTimeTypeParams("false", "none");
dimension1.setTypeParams(dimensionTimeTypeParams);
dimensions.add(dimension1);
dimensions.add(new Dim("公司名称", "company_name", DimensionType.categorical.name(), 1));
dimensions.add(new Dim("总部地点", "headquarter_address", DimensionType.categorical.name(), 1));
dimensions.add(new Dim("公司成立时间", "company_established_time", DimensionType.categorical.name(), 1));
dimensions.add(new Dim("创始人", "founder", DimensionType.categorical.name(), 1));
dimensions.add(new Dim("首席执行官", "ceo", DimensionType.categorical.name(), 1));
modelDetail.setDimensions(dimensions);
List<Identify> identifiers = new ArrayList<>();
identifiers.add(new Identify("公司id", IdentifyType.primary.name(), "company_id"));
modelDetail.setIdentifiers(identifiers);
List<Measure> measures = new ArrayList<>();
measures.add(new Measure("年营业额", "annual_turnover", AggOperatorEnum.SUM.name(), 1));
Measure measure = new Measure("员工数", "employee_count", AggOperatorEnum.SUM.name(), 1);
measures.add(measure);
modelDetail.setMeasures(measures);
modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("SELECT imp_date,company_id,company_name,headquarter_address,"
+ "company_established_time,founder,ceo,annual_turnover,employee_count FROM company");
modelReq.setModelDetail(modelDetail);
modelService.createModel(modelReq, user);
}
// model id 10
public void addModel_2() throws Exception {
ModelReq modelReq = new ModelReq();
modelReq.setName("品牌");
modelReq.setBizName("brand");
modelReq.setDescription("品牌");
modelReq.setDatabaseId(1L);
modelReq.setDomainId(4L);
modelReq.setViewers(Arrays.asList("admin", "tom", "jack"));
modelReq.setViewOrgs(Collections.singletonList("1"));
modelReq.setAdmins(Collections.singletonList("admin"));
modelReq.setAdminOrgs(Collections.emptyList());
ModelDetail modelDetail = new ModelDetail();
List<Dim> dimensions = new ArrayList<>();
Dim dimension1 = new Dim("", "imp_date", DimensionType.time.name(), 0);
DimensionTimeTypeParams dimensionTimeTypeParams = new DimensionTimeTypeParams("false", "none");
dimension1.setTypeParams(dimensionTimeTypeParams);
dimensions.add(dimension1);
dimensions.add(new Dim("品牌名称", "brand_name", DimensionType.categorical.name(), 1));
dimensions.add(new Dim("品牌成立时间", "brand_established_time", DimensionType.categorical.name(), 1));
dimensions.add(new Dim("法定代表人", "legal_representative", DimensionType.categorical.name(), 1));
modelDetail.setDimensions(dimensions);
List<Identify> identifiers = new ArrayList<>();
identifiers.add(new Identify("品牌id", IdentifyType.primary.name(), "brand_id"));
identifiers.add(new Identify("公司id", IdentifyType.foreign.name(), "company_id"));
modelDetail.setIdentifiers(identifiers);
List<Measure> measures = new ArrayList<>();
measures.add(new Measure("注册资本", "registered_capital", AggOperatorEnum.SUM.name(), 1));
modelDetail.setMeasures(measures);
modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("SELECT imp_date,brand_id,brand_name,brand_established_time,"
+ "company_id,legal_representative,registered_capital FROM brand");
modelReq.setModelDetail(modelDetail);
modelService.createModel(modelReq, user);
}
// model id 11
public void addModel_3() throws Exception {
ModelReq modelReq = new ModelReq();
modelReq.setName("公司各品牌收入排名");
modelReq.setBizName("company_revenue");
modelReq.setDescription("公司各品牌收入排名");
modelReq.setDatabaseId(1L);
modelReq.setDomainId(4L);
modelReq.setViewers(Arrays.asList("admin", "tom", "jack"));
modelReq.setViewOrgs(Collections.singletonList("1"));
modelReq.setAdmins(Collections.singletonList("admin"));
modelReq.setAdminOrgs(Collections.emptyList());
ModelDetail modelDetail = new ModelDetail();
List<Dim> dimensions = new ArrayList<>();
Dim dimension1 = new Dim("", "imp_date", DimensionType.time.name(), 0);
DimensionTimeTypeParams dimensionTimeTypeParams = new DimensionTimeTypeParams("false", "none");
dimension1.setTypeParams(dimensionTimeTypeParams);
dimensions.add(dimension1);
modelDetail.setDimensions(dimensions);
List<Identify> identifiers = new ArrayList<>();
identifiers.add(new Identify("公司id", IdentifyType.foreign.name(), "company_id"));
identifiers.add(new Identify("品牌id", IdentifyType.foreign.name(), "brand_id"));
modelDetail.setIdentifiers(identifiers);
List<Measure> measures = new ArrayList<>();
Measure measure = new Measure("营收占比", "revenue_proportion", AggOperatorEnum.SUM.name(), 1);
measures.add(measure);
measures.add(new Measure("利润占比", "profit_proportion", AggOperatorEnum.SUM.name(), 1));
measures.add(new Measure("支出占比", "expenditure_proportion", AggOperatorEnum.SUM.name(), 1));
modelDetail.setMeasures(measures);
modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("SELECT imp_date,company_id,brand_id,revenue_proportion,"
+ "profit_proportion,expenditure_proportion FROM company_revenue");
modelReq.setModelDetail(modelDetail);
modelService.createModel(modelReq, user);
MetricResp metricResp = metricService.getMetric(13L, user);
MetricReq metricReq = new MetricReq();
BeanUtils.copyProperties(metricResp, metricReq);
metricReq.setAlias("收入比例");
metricService.updateMetric(metricReq, user);
}
// model id 12
public void addModel_4() throws Exception {
ModelReq modelReq = new ModelReq();
modelReq.setName("公司品牌历年收入");
modelReq.setBizName("company_brand_revenue");
modelReq.setDescription("公司品牌历年收入");
modelReq.setDatabaseId(1L);
modelReq.setDomainId(4L);
modelReq.setViewers(Arrays.asList("admin", "tom", "jack"));
modelReq.setViewOrgs(Collections.singletonList("1"));
modelReq.setAdmins(Collections.singletonList("admin"));
modelReq.setAdminOrgs(Collections.emptyList());
ModelDetail modelDetail = new ModelDetail();
List<Dim> dimensions = new ArrayList<>();
Dim dimension1 = new Dim("", "imp_date", DimensionType.time.name(), 0);
DimensionTimeTypeParams dimensionTimeTypeParams = new DimensionTimeTypeParams("false", "none");
dimension1.setTypeParams(dimensionTimeTypeParams);
dimensions.add(dimension1);
dimensions.add(new Dim("年份", "year_time", DimensionType.categorical.name(), 1));
modelDetail.setDimensions(dimensions);
List<Identify> identifiers = new ArrayList<>();
identifiers.add(new Identify("品牌id", IdentifyType.foreign.name(), "brand_id"));
modelDetail.setIdentifiers(identifiers);
List<Measure> measures = new ArrayList<>();
measures.add(new Measure("营收", "revenue", AggOperatorEnum.SUM.name(), 1));
measures.add(new Measure("利润", "profit", AggOperatorEnum.SUM.name(), 1));
measures.add(new Measure("营收同比增长", "revenue_growth_year_on_year", AggOperatorEnum.SUM.name(), 1));
measures.add(new Measure("利润同比增长", "profit_growth_year_on_year", AggOperatorEnum.SUM.name(), 1));
modelDetail.setMeasures(measures);
modelDetail.setQueryType("sql_query");
modelDetail.setSqlQuery("SELECT imp_date,year_time,brand_id,revenue,profit,"
+ "revenue_growth_year_on_year,profit_growth_year_on_year FROM company_brand_revenue");
modelReq.setModelDetail(modelDetail);
modelService.createModel(modelReq, user);
}
public void addModelRela_1() {
List<JoinCondition> joinConditions = Lists.newArrayList();
joinConditions.add(new JoinCondition("company_id", "company_id", FilterOperatorEnum.EQUALS));
ModelRela modelRelaReq = new ModelRela();
modelRelaReq.setDomainId(4L);
modelRelaReq.setFromModelId(9L);
modelRelaReq.setToModelId(10L);
modelRelaReq.setJoinType("inner join");
modelRelaReq.setJoinConditions(joinConditions);
modelRelaService.save(modelRelaReq, user);
}
public void addModelRela_2() {
List<JoinCondition> joinConditions = Lists.newArrayList();
joinConditions.add(new JoinCondition("company_id", "company_id", FilterOperatorEnum.EQUALS));
ModelRela modelRelaReq = new ModelRela();
modelRelaReq.setDomainId(4L);
modelRelaReq.setFromModelId(9L);
modelRelaReq.setToModelId(11L);
modelRelaReq.setJoinType("inner join");
modelRelaReq.setJoinConditions(joinConditions);
modelRelaService.save(modelRelaReq, user);
}
public void addModelRela_3() {
List<JoinCondition> joinConditions = Lists.newArrayList();
joinConditions.add(new JoinCondition("brand_id", "brand_id", FilterOperatorEnum.EQUALS));
ModelRela modelRelaReq = new ModelRela();
modelRelaReq.setDomainId(4L);
modelRelaReq.setFromModelId(10L);
modelRelaReq.setToModelId(11L);
modelRelaReq.setJoinType("inner join");
modelRelaReq.setJoinConditions(joinConditions);
modelRelaService.save(modelRelaReq, user);
}
public void addModelRela_4() {
List<JoinCondition> joinConditions = Lists.newArrayList();
joinConditions.add(new JoinCondition("brand_id", "brand_id", FilterOperatorEnum.EQUALS));
ModelRela modelRelaReq = new ModelRela();
modelRelaReq.setDomainId(4L);
modelRelaReq.setFromModelId(10L);
modelRelaReq.setToModelId(12L);
modelRelaReq.setJoinType("inner join");
modelRelaReq.setJoinConditions(joinConditions);
modelRelaService.save(modelRelaReq, user);
}
}

View File

@@ -25,6 +25,9 @@ public class HeadlessDemoLoader implements CommandLineRunner {
@Autowired
private BenchMarkDemoDataLoader benchMarkDemoLoader;
@Autowired
private DuSQLDemoDataLoader duSQLDemoDataLoader;
@Value("${demo.enabled:false}")
private boolean demoEnabled;
@@ -36,6 +39,7 @@ public class HeadlessDemoLoader implements CommandLineRunner {
}
modelDataDemoLoader.doRun();
benchMarkDemoLoader.doRun();
duSQLDemoDataLoader.doRun();
isLoad = true;
}
@@ -50,4 +54,4 @@ public class HeadlessDemoLoader implements CommandLineRunner {
return isLoad;
}
}
}

View File

@@ -91,4 +91,4 @@ inMemoryEmbeddingStore:
query:
optimizer:
enable: true
enable: true

View File

@@ -0,0 +1,8 @@
阿里云 _10_20 5
天猫 _10_20 5
腾讯游戏 _10_20 5
度小满 _10_20 5
京东金融 _10_20 5

View File

@@ -0,0 +1,8 @@
张勇 _10_22 5
马化腾 _10_22 5
朱光 _10_22 5
刘强东 _10_22 5

View File

@@ -0,0 +1,5 @@
百度集团 _9_15 5
阿里巴巴集团 _9_15 5
深圳市腾讯计算机系统有限公司 _9_15 5
北京京东世纪贸易有限公司 _9_15 5
网易公司 _9_15 5

View File

@@ -0,0 +1,4 @@
北京 _9_16 5
杭州 _9_16 5
深圳 _9_16 5

View File

@@ -0,0 +1,7 @@
李彦宏 _9_18 5
马云 _9_18 5
马化腾 _9_18 5
刘强东 _9_18 5
丁磊 _9_18 5

View File

@@ -0,0 +1,7 @@
李彦宏 _9_19 5
张勇 _9_19 5
刘炽平 _9_19 5
刘强东 _9_19 5
丁磊 _9_19 5

View File

@@ -1111,4 +1111,29 @@ MERGE INTO song(imp_date,song_name,artist_name,country,f_id,g_name,rating,langua
MERGE INTO song(imp_date,song_name,artist_name,country,f_id,g_name,rating,languages,releasedate,resolution) VALUES (DATEADD('DAY', 0, CURRENT_DATE()),'打败它','Michel','英国',5,'流行',8,'英文','17-MAR-2002',720);
MERGE INTO song(imp_date,song_name,artist_name,country,f_id,g_name,rating,languages,releasedate,resolution) VALUES (DATEADD('DAY', 0, CURRENT_DATE()),'阿杰伊阿卡什','Topu','印度',6,'现代',10,'孟加拉语','27-MAR-2004',320);
insert into company(imp_date,company_id,company_name,headquarter_address,company_established_time,founder,ceo,annual_turnover,employee_count) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_131','百度集团','北京','2000','李彦宏','李彦宏',102300000000,40000);
insert into company(imp_date,company_id,company_name,headquarter_address,company_established_time,founder,ceo,annual_turnover,employee_count) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_132','阿里巴巴集团','杭州','1999年','马云','张勇',376800000000,103699);
insert into company(imp_date,company_id,company_name,headquarter_address,company_established_time,founder,ceo,annual_turnover,employee_count) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_133','深圳市腾讯计算机系统有限公司','深圳','1998','马化腾','刘炽平',321600000000,56310);
insert into company(imp_date,company_id,company_name,headquarter_address,company_established_time,founder,ceo,annual_turnover,employee_count) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_134','北京京东世纪贸易有限公司','北京','1998','刘强东','刘强东',28800000000,179000);
insert into company(imp_date,company_id,company_name,headquarter_address,company_established_time,founder,ceo,annual_turnover,employee_count) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_135','网易公司','杭州','1997','丁磊','丁磊',67500000000,20000);
insert into brand(imp_date,brand_id,brand_name,brand_established_time,company_id,legal_representative,registered_capital) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_136','阿里云','2009年9月10日','item_enterprise_13_134','张勇',50000000);
insert into brand(imp_date,brand_id,brand_name,brand_established_time,company_id,legal_representative,registered_capital) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_137','天猫','2012年1月11日','item_enterprise_13_134','张勇',100000000);
insert into brand(imp_date,brand_id,brand_name,brand_established_time,company_id,legal_representative,registered_capital) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_138','腾讯游戏','2003','item_enterprise_13_131','马化腾',50000000);
insert into brand(imp_date,brand_id,brand_name,brand_established_time,company_id,legal_representative,registered_capital) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_139','度小满','2018','item_enterprise_13_132','朱光',100000000);
insert into brand(imp_date,brand_id,brand_name,brand_established_time,company_id,legal_representative,registered_capital) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_140','京东金融','2017','item_enterprise_13_134','刘强东',100000000);
insert into company_revenue(imp_date,company_id,brand_id,revenue_proportion,profit_proportion,expenditure_proportion) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_131','item_enterprise_13_139',10,10,30);
insert into company_revenue(imp_date,company_id,brand_id,revenue_proportion,profit_proportion,expenditure_proportion) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_134','item_enterprise_13_138',80,80,60);
insert into company_revenue(imp_date,company_id,brand_id,revenue_proportion,profit_proportion,expenditure_proportion) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_135','item_enterprise_13_139',80,80,60);
insert into company_revenue(imp_date,company_id,brand_id,revenue_proportion,profit_proportion,expenditure_proportion) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_131','item_enterprise_13_137',80,80,60);
insert into company_revenue(imp_date,company_id,brand_id,revenue_proportion,profit_proportion,expenditure_proportion) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_135','item_enterprise_13_137',10,10,30);
insert into company_brand_revenue(imp_date,year_time,brand_id,revenue,profit,revenue_growth_year_on_year,profit_growth_year_on_year) VALUES (DATEADD('DAY', -1, CURRENT_DATE()), '2018','item_enterprise_13_138',500000000,-300000000,10,-10);
insert into company_brand_revenue(imp_date,year_time,brand_id,revenue,profit,revenue_growth_year_on_year,profit_growth_year_on_year) VALUES (DATEADD('DAY', -1, CURRENT_DATE()), '2019','item_enterprise_13_136',100000000000,50000000000,100,50);
insert into company_brand_revenue(imp_date,year_time,brand_id,revenue,profit,revenue_growth_year_on_year,profit_growth_year_on_year) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'2018','item_enterprise_13_137',100000000000,50000000000,100,-10);
insert into company_brand_revenue(imp_date,year_time,brand_id,revenue,profit,revenue_growth_year_on_year,profit_growth_year_on_year) VALUES (DATEADD('DAY', -1, CURRENT_DATE()), '2018','item_enterprise_13_139',500000000,50000000000,10,50);
insert into company_brand_revenue(imp_date,year_time,brand_id,revenue,profit,revenue_growth_year_on_year,profit_growth_year_on_year) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'2018','item_enterprise_13_138',100000000000,-300000000,10,50);
-- benchmark

View File

@@ -455,6 +455,51 @@ CREATE TABLE IF NOT EXISTS `song` (
);
COMMENT ON TABLE song IS 'song';
CREATE TABLE IF NOT EXISTS `company` (
`imp_date` varchar(50) ,
`company_id` varchar(50) NOT NULL ,
`company_name` varchar(50) NOT NULL ,
`headquarter_address` varchar(50) NOT NULL ,
`company_established_time` varchar(20) NOT NULL ,
`founder` varchar(20) NOT NULL ,
`ceo` varchar(20) NOT NULL ,
`annual_turnover` bigint(15) ,
`employee_count` int(7) ,
PRIMARY KEY (`company_id`)
);
CREATE TABLE IF NOT EXISTS `brand` (
`imp_date` varchar(50) ,
`brand_id` varchar(50) NOT NULL ,
`brand_name` varchar(50) NOT NULL ,
`brand_established_time` varchar(20) NOT NULL ,
`company_id` varchar(50) NOT NULL ,
`legal_representative` varchar(20) NOT NULL ,
`registered_capital` bigint(15) ,
PRIMARY KEY (`brand_id`)
);
CREATE TABLE IF NOT EXISTS `company_revenue` (
`imp_date` varchar(50) ,
`company_id` varchar(50) NOT NULL ,
`brand_id` varchar(50) NOT NULL ,
`revenue_proportion` double NOT NULL,
`profit_proportion` double NOT NULL ,
`expenditure_proportion` double NOT NULL
);
CREATE TABLE IF NOT EXISTS `company_brand_revenue` (
`imp_date` varchar(50) ,
`year_time` varchar(10) NOT NULL ,
`brand_id` varchar(50) NOT NULL ,
`revenue` bigint(15) NOT NULL,
`profit` bigint(15) NOT NULL ,
`revenue_growth_year_on_year` double NOT NULL ,
`profit_growth_year_on_year` double NOT NULL
);
CREATE TABLE IF NOT EXISTS s2_sys_parameter
(
id INT PRIMARY KEY AUTO_INCREMENT,
@@ -498,4 +543,4 @@ CREATE TABLE IF NOT EXISTS `s2_app` (
created_by VARCHAR(255),
updated_at TIMESTAMP,
updated_by VARCHAR(255)
);
);

View File

@@ -1,2 +1,2 @@
root=.
CustomDictionaryPath=data/dictionary/custom/DimValue_1_1.txt;data/dictionary/custom/DimValue_1_2.txt;data/dictionary/custom/DimValue_1_3.txt;data/dictionary/custom/benchmark_cspider.txt;
CustomDictionaryPath=data/dictionary/custom/DimValue_1_1.txt;data/dictionary/custom/DimValue_1_2.txt;data/dictionary/custom/DimValue_9_15.txt;data/dictionary/custom/DimValue_9_16.txt;data/dictionary/custom/DimValue_9_18.txt;data/dictionary/custom/DimValue_9_19.txt;data/dictionary/custom/DimValue_10_20.txt;data/dictionary/custom/DimValue_10_22.txt;