diff --git a/chat/python/config/run_config.ini b/chat/python/config/run_config.ini
index 48f5ac32e..d1c0c143f 100644
--- a/chat/python/config/run_config.ini
+++ b/chat/python/config/run_config.ini
@@ -25,4 +25,4 @@ LLM_PROVIDER_NAME = openai
 [LLMModel]
 MODEL_NAME = gpt-3.5-turbo-16k
 OPENAI_API_KEY = YOUR_API_KEY
-TEMPERATURE = 0.0
\ No newline at end of file
+TEMPERATURE = 0.0
diff --git a/chat/python/requirements.txt b/chat/python/requirements.txt
index f50323794..5b5da5648 100644
--- a/chat/python/requirements.txt
+++ b/chat/python/requirements.txt
@@ -6,4 +6,4 @@ tiktoken==0.3.3
 uvicorn[standard]==0.21.1
 pandas==1.5.3
 loguru==0.7.2
-sqlglot==19.5.1
\ No newline at end of file
+sqlglot==19.5.1
diff --git a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoProcessor.java b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoProcessor.java
index e7d6b2a42..9d948ad97 100644
--- a/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoProcessor.java
+++ b/chat/server/src/main/java/com/tencent/supersonic/chat/server/processor/parse/ParseInfoProcessor.java
@@ -102,8 +102,18 @@ public class ParseInfoProcessor implements ParseResultProcessor {
     private Set<SchemaElement> getElements(Set<Long> modelIds, List<String> allFields, List<SchemaElement> elements) {
         return elements.stream()
-                .filter(schemaElement -> modelIds.contains(schemaElement.getModel()) && allFields.contains(
-                        schemaElement.getName())
+                .filter(schemaElement -> {
+                    if (CollectionUtils.isEmpty(schemaElement.getAlias())) {
+                        return modelIds.contains(schemaElement.getModel()) && allFields.contains(
+                                schemaElement.getName());
+                    }
+                    Set<String> allFieldsSet = new HashSet<>(allFields);
+                    Set<String> aliasSet = new HashSet<>(schemaElement.getAlias());
+                    List<String> intersection = allFieldsSet.stream()
+                            .filter(aliasSet::contains).collect(Collectors.toList());
+                    return modelIds.contains(schemaElement.getModel()) && (allFields.contains(
+                            schemaElement.getName()) || !CollectionUtils.isEmpty(intersection));
+                }
                 ).collect(Collectors.toSet());
     }
@@ -208,4 +218,4 @@ public class ParseInfoProcessor implements ParseResultProcessor {
                         (value1, value2) -> value2));
     }
-}
\ No newline at end of file
+}
diff --git a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java
index 6782c8208..7f665c5a0 100644
--- a/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java
+++ b/common/src/main/java/com/tencent/supersonic/common/util/jsqlparser/SqlParserReplaceHelper.java
@@ -23,6 +23,7 @@ import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
 import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
 import net.sf.jsqlparser.parser.CCJSqlParserUtil;
 import net.sf.jsqlparser.schema.Column;
+import net.sf.jsqlparser.schema.Table;
 import net.sf.jsqlparser.statement.select.GroupByElement;
 import net.sf.jsqlparser.statement.select.Join;
 import net.sf.jsqlparser.statement.select.OrderByElement;
@@ -231,6 +232,9 @@ public class SqlParserReplaceHelper {
         if (!CollectionUtils.isEmpty(joins)) {
             for (Join join : joins) {
                 join.getOnExpression().accept(visitor);
+                if (!(join.getRightItem() instanceof SubSelect)) {
+                    continue;
+                }
                 SelectBody subSelectBody = ((SubSelect) join.getRightItem()).getSelectBody();
                 List<PlainSelect> plainSelectList = new ArrayList<>();
                 plainSelectList.add((PlainSelect) subSelectBody);
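Both SqlParserReplaceHelper hunks guard the rename logic so that a JOIN whose right item is a plain Table is handled directly instead of being blindly cast to SubSelect. For comparison only, the same "rename every real table, reach subqueries by tree traversal" idea can be sketched in Python with the sqlglot package pinned in requirements.txt above; the query text and the name t_new are made-up examples, not part of this change:

    import sqlglot
    from sqlglot import exp

    def rename_tables(sql: str, new_name: str) -> str:
        # Walk every table reference in the parsed tree -- including the
        # right side of each JOIN -- and swap in the new identifier.
        # Subqueries are covered because find_all traverses the whole tree.
        tree = sqlglot.parse_one(sql)
        for table in tree.find_all(exp.Table):
            table.set("this", exp.to_identifier(new_name))
        return tree.sql()

    print(rename_tables("SELECT a FROM t1 JOIN t2 ON t1.id = t2.id", "t_new"))
    # t1 and t2 both become t_new; the t1./t2. column qualifiers are
    # Column nodes, not Table nodes, so they are left untouched.

Because find_all already walks into subselects, the sketch needs no explicit SubSelect branch, which is exactly the case distinction the jsqlparser code below has to make by hand.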
@@ -414,12 +418,17 @@ public class SqlParserReplaceHelper {
         List<Join> joins = painSelect.getJoins();
         if (!CollectionUtils.isEmpty(joins)) {
             for (Join join : joins) {
-                SelectBody subSelectBody = ((SubSelect) join.getRightItem()).getSelectBody();
-                List<PlainSelect> plainSelectList = new ArrayList<>();
-                plainSelectList.add((PlainSelect) subSelectBody);
-                List<PlainSelect> subPlainSelects = SqlParserSelectHelper.getPlainSelects(plainSelectList);
-                for (PlainSelect subPlainSelect : subPlainSelects) {
-                    subPlainSelect.getFromItem().accept(new TableNameReplaceVisitor(tableName));
+                if (join.getRightItem() instanceof SubSelect) {
+                    SelectBody subSelectBody = ((SubSelect) join.getRightItem()).getSelectBody();
+                    List<PlainSelect> plainSelectList = new ArrayList<>();
+                    plainSelectList.add((PlainSelect) subSelectBody);
+                    List<PlainSelect> subPlainSelects = SqlParserSelectHelper.getPlainSelects(plainSelectList);
+                    for (PlainSelect subPlainSelect : subPlainSelects) {
+                        subPlainSelect.getFromItem().accept(new TableNameReplaceVisitor(tableName));
+                    }
+                } else if (join.getRightItem() instanceof Table) {
+                    Table table = (Table) join.getRightItem();
+                    table.setName(tableName);
                 }
             }
         }
diff --git a/evaluation/README_CN.md b/evaluation/README_CN.md
new file mode 100644
index 000000000..6b2784ac2
--- /dev/null
+++ b/evaluation/README_CN.md
@@ -0,0 +1,5 @@
+# Evaluation Workflow
+
+1. Start the project normally (the LLM service must be running).
+2. Put the questions to evaluate under the evaluation/data directory, e.g. internet.txt; put the gold SQL for those questions in the same directory, e.g. gold_example_dusql.txt.
+3. Run the evaluation.sh script; it builds the table data, fetches the model's predicted SQL, and runs the comparison logic. Accuracy is printed on the command line, and failing cases are written to eval.json in the same directory.
diff --git a/evaluation/build_pred_result.py b/evaluation/build_pred_result.py
new file mode 100644
index 000000000..7c85bb52a
--- /dev/null
+++ b/evaluation/build_pred_result.py
@@ -0,0 +1,74 @@
+import requests
+import logging
+import json
+import jwt
+import time
+import os
+import yaml
+
+
+def read_query(input_path):
+    # one evaluation question per line
+    result = []
+    with open(input_path, "r") as f:
+        for line in f.readlines():
+            result.append(line.strip('\n'))
+    return result
+
+
+def write_sql(output_path, result):
+    with open(output_path, mode='a') as file:
+        file.writelines(result)
+
+
+def get_pred_sql(query, url, agent_id, chat_id, authorization, default_sql):
+    url = url + "/api/chat/query/parse"
+    data = {"agentId": agent_id, "chatId": chat_id, "queryText": query}
+    header = {"Authorization": authorization}
+    try:
+        result = requests.post(url=url, headers=header, json=data)
+        if result.status_code == 200:
+            data = result.json()["data"]
+            selected_parses = data["selectedParses"]
+            if selected_parses is not None and len(selected_parses) > 0:
+                query_sql = selected_parses[0]["sqlInfo"]["querySQL"]
+                query_sql = query_sql.replace("`dusql`.", "").replace("dusql", "").replace("\n", "")
+                return query_sql + '\n'
+        return default_sql + '\n'
+    except Exception as e:
+        # the request itself may have failed, so only the URL and the
+        # exception are safe to report here
+        print(url)
+        print(e)
+        logging.info(e)
+        return default_sql + '\n'
+
+
+def get_authorization():
+    exp = time.time() + 1000
+    token = jwt.encode({"token_user_name": "admin", "exp": exp}, "secret", algorithm="HS512")
+    return "Bearer " + token
+
+
+def get_pred_result():
+    current_directory = os.path.dirname(os.path.abspath(__file__))
+    config_file = current_directory + "/config/config.yaml"
+    with open(config_file, 'r') as file:
+        config = yaml.safe_load(file)
+    input_path = current_directory + "/data/" + config["domain"] + ".txt"
+    pred_sql_path = current_directory + "/data/" + "pred_example_dusql.txt"
+    if os.path.exists(pred_sql_path):
+        os.remove(pred_sql_path)
+        print("pred_sql_path removed!")
+    agent_id = config["agent_id"]
+    chat_id = config["chat_id"]
+    url = config["url"]
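+    # get_authorization() above signs a short-lived HS512 JWT
+    # ({"token_user_name": "admin"}) with the literal secret "secret";
+    # presumably the chat server is configured with the same signing
+    # secret, otherwise this Bearer token is rejected.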
+    authorization = get_authorization()
+    print(input_path)
+    print(pred_sql_path)
+    questions = read_query(input_path)
+    pred_sql_list = []
+    # fallback line emitted when the parse API returns nothing for a question
+    default_sql = "select * from tablea "
+    for question in questions:
+        pred_sql = get_pred_sql(question, url, agent_id, chat_id, authorization, default_sql)
+        pred_sql_list.append(pred_sql)
+    write_sql(pred_sql_path, pred_sql_list)
+
+
+if __name__ == "__main__":
+    get_pred_result()
diff --git a/evaluation/build_tables.py b/evaluation/build_tables.py
new file mode 100644
index 000000000..6568744da
--- /dev/null
+++ b/evaluation/build_tables.py
@@ -0,0 +1,340 @@
+import sqlite3
+import os
+import datetime
+import yaml
+
+
+def build_internet(path, day):
+    # partition date written into every row; day=-1 means "yesterday"
+    imp_date = (datetime.datetime.now() + datetime.timedelta(days=day)).strftime("%Y-%m-%d")
+    print(imp_date)
+    conn = sqlite3.connect(path + '/internet.db')
+    cursor = conn.cursor()
+    create_table_query = '''
+    CREATE TABLE IF NOT EXISTS company (
+        `imp_date` varchar(50) ,
+        `company_id` varchar(50) NOT NULL ,
+        `company_name` varchar(50) NOT NULL ,
+        `headquarter_address` varchar(50) NOT NULL ,
+        `company_established_time` varchar(20) NOT NULL ,
+        `founder` varchar(20) NOT NULL ,
+        `ceo` varchar(20) NOT NULL ,
+        `annual_turnover` bigint(15) ,
+        `employee_count` int(7) ,
+        PRIMARY KEY (`company_id`)
+    )
+    '''
+    cursor.execute(create_table_query)
+    insert_data_query = '''
+    INSERT INTO company (imp_date,company_id,company_name,headquarter_address,company_established_time,founder,ceo,annual_turnover,employee_count)
+    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
+    '''
+    data = [
+        (imp_date, "item_enterprise_13_131", "百度集团", "北京", "2000", "李彦宏", "李彦宏", 102300000000, 40000),
+        (imp_date, "item_enterprise_13_132", "阿里巴巴集团", "杭州", "1999年", "马云", "张勇", 376800000000, 103699),
+        (imp_date, "item_enterprise_13_133", "深圳市腾讯计算机系统有限公司", "深圳", "1998", "马化腾", "刘炽平", 321600000000, 56310),
+        (imp_date, "item_enterprise_13_134", "北京京东世纪贸易有限公司", "北京", "1998", "刘强东", "刘强东", 28800000000, 179000),
+        (imp_date, "item_enterprise_13_135", "网易公司", "杭州", "1997", "丁磊", "丁磊", 67500000000, 20000)
+    ]
+    cursor.executemany(insert_data_query, data)
+    conn.commit()
+
+    create_table_query = '''
+    CREATE TABLE IF NOT EXISTS brand (
+        `imp_date` varchar(50) ,
+        `brand_id` varchar(50) NOT NULL ,
+        `brand_name` varchar(50) NOT NULL ,
+        `brand_established_time` varchar(20) NOT NULL ,
+        `company_id` varchar(50) NOT NULL ,
+        `legal_representative` varchar(20) NOT NULL ,
+        `registered_capital` bigint(15) ,
+        PRIMARY KEY (`brand_id`)
+    )
+    '''
+    cursor.execute(create_table_query)
+    insert_data_query = '''
+    INSERT INTO brand (imp_date,brand_id,brand_name,brand_established_time,company_id,legal_representative,registered_capital)
+    VALUES (?, ?, ?, ?, ?, ?, ?)
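+    -- 7 placeholders, one per column in the INSERT list above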
+    '''
+    data = [
+        (imp_date, "item_enterprise_13_136", "阿里云", "2009年9月10日", "item_enterprise_13_134", "张勇", 50000000),
+        (imp_date, "item_enterprise_13_137", "天猫", "2012年1月11日", "item_enterprise_13_134", "张勇", 100000000),
+        (imp_date, "item_enterprise_13_138", "腾讯游戏", "2003", "item_enterprise_13_131", "马化腾", 50000000),
+        (imp_date, "item_enterprise_13_139", "度小满", "2018", "item_enterprise_13_132", "朱光", 100000000),
+        (imp_date, "item_enterprise_13_140", "京东金融", "2017", "item_enterprise_13_134", "刘强东", 100000000)
+    ]
+    cursor.executemany(insert_data_query, data)
+    conn.commit()
+
+    create_table_query = '''
+    CREATE TABLE IF NOT EXISTS company_revenue (
+        `imp_date` varchar(50) ,
+        `company_id` varchar(50) NOT NULL ,
+        `brand_id` varchar(50) NOT NULL ,
+        `revenue_proportion` double NOT NULL,
+        `profit_proportion` double NOT NULL ,
+        `expenditure_proportion` double NOT NULL
+    )
+    '''
+    cursor.execute(create_table_query)
+    insert_data_query = '''
+    INSERT INTO company_revenue (imp_date,company_id,brand_id,revenue_proportion,profit_proportion,expenditure_proportion)
+    VALUES (?, ?, ?, ?, ?, ?)
+    '''
+    data = [
+        (imp_date, "item_enterprise_13_131", "item_enterprise_13_139", 0.1, 0.1, 0.3),
+        (imp_date, "item_enterprise_13_134", "item_enterprise_13_138", 0.8, 0.8, 0.6),
+        (imp_date, "item_enterprise_13_135", "item_enterprise_13_139", 0.8, 0.8, 0.6),
+        (imp_date, "item_enterprise_13_131", "item_enterprise_13_137", 0.8, 0.8, 0.6),
+        (imp_date, "item_enterprise_13_135", "item_enterprise_13_137", 0.1, 0.1, 0.3)
+    ]
+    cursor.executemany(insert_data_query, data)
+    conn.commit()
+
+    create_table_query = '''
+    CREATE TABLE IF NOT EXISTS company_brand_revenue (
+        `imp_date` varchar(50) ,
+        `year_time` varchar(10) NOT NULL ,
+        `brand_id` varchar(50) NOT NULL ,
+        `revenue` bigint(15) NOT NULL,
+        `profit` bigint(15) NOT NULL ,
+        `revenue_growth_year_on_year` double NOT NULL ,
+        `profit_growth_year_on_year` double NOT NULL
+    )
+    '''
+    cursor.execute(create_table_query)
+    insert_data_query = '''
+    INSERT INTO company_brand_revenue (imp_date,year_time,brand_id,revenue,profit,revenue_growth_year_on_year,profit_growth_year_on_year)
+    VALUES (?, ?, ?, ?, ?, ?, ?)
+    '''
+    data = [
+        (imp_date, "2018", "item_enterprise_13_138", 500000000, -300000000, 0.1, -0.1),
+        (imp_date, "2019", "item_enterprise_13_136", 100000000000, 50000000000, 1, 0.5),
+        (imp_date, "2018", "item_enterprise_13_137", 100000000000, 50000000000, 1, -0.1),
+        (imp_date, "2018", "item_enterprise_13_139", 500000000, 50000000000, 0.1, 0.5),
+        (imp_date, "2018", "item_enterprise_13_138", 100000000000, -300000000, 0.1, 0.5)
+    ]
+    cursor.executemany(insert_data_query, data)
+    conn.commit()
+    conn.close()
+
+
+def build_china_travel_agency(path, day):
+    imp_date = (datetime.datetime.now() + datetime.timedelta(days=day)).strftime("%Y-%m-%d")
+    print(imp_date)
+    conn = sqlite3.connect(path + '/china_travel_agency.db')
+    cursor = conn.cursor()
+    create_table_query = '''
+    CREATE TABLE IF NOT EXISTS `travel_agency` (
+        `imp_date` varchar(50) ,
+        `travel_agency_id` varchar(50) NOT NULL,
+        `travel_agency_name` varchar(50) NOT NULL,
+        `travel_agency_level` varchar(50) NOT NULL,
+        `number_countrie_outbound_travel` int(7) ,
+        `number_domestic_tourist_cities` int(7) ,
+        `number_outbound_travel_routes` int(7) ,
+        `number_domestic_travel_routes` int(7) ,
+        `asia_ranking` int(7) ,
+        `number_overseas_tourists_received` int(7) ,
+        `number_overseas_companies` int(7) ,
+        `number_holding_subsidiaries` int(7) ,
+        `number_traveling_salesmen_business_relationships` int(7) ,
+        `number_duty_free_shops` int(7) ,
+        PRIMARY KEY (`travel_agency_id`)
+    )
+    '''
+    cursor.execute(create_table_query)
+    insert_data_query = '''
+    INSERT INTO travel_agency (imp_date, travel_agency_id,travel_agency_name,travel_agency_level,number_countrie_outbound_travel,
+    number_domestic_tourist_cities,number_outbound_travel_routes,number_domestic_travel_routes,
+    asia_ranking,number_overseas_tourists_received,number_overseas_companies,number_holding_subsidiaries,
+    number_traveling_salesmen_business_relationships,number_duty_free_shops)
+    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+    '''
+    data = [
+        (imp_date, "item_enterprise_7_56", "中国国旅", "3A", 50, 30, 500, 1000, 1, 10000000, 5, 70, 5000, 100),
+        (imp_date, "item_enterprise_7_57", "中旅总社", "4A", 90, 100, 1000, 2000, 100, 20000000, 20, 120, 8000, 180),
+        (imp_date, "item_enterprise_7_58", "中青旅控股股份有限公司", "5A", 50, 30, 500, 2000, 1, 10000000, 5, 70, 5000, 100),
+        (imp_date, "item_enterprise_7_59", "中国康辉旅游集团有限公司", "5A", 50, 30, 1000, 1000, 100, 10000000, 5, 70, 8000, 100),
+        (imp_date, "item_enterprise_7_60", "众信旅游集团股份有限公司", "5A", 90, 30, 1000, 2000, 1, 20000000, 20, 70, 8000, 180)
+    ]
+    cursor.executemany(insert_data_query, data)
+    conn.commit()
+
+    create_table_query = '''
+    CREATE TABLE IF NOT EXISTS `outbound_travel_routes` (
+        `imp_date` varchar(50) ,
+        `outbound_route_id` varchar(50) NOT NULL ,
+        `outbound_route_name` varchar(50) NOT NULL ,
+        `travel_agency_id` varchar(50) NOT NULL ,
+        `outbound_departure_city` varchar(50) NOT NULL ,
+        `outbound_days` bigint(5) ,
+        `adult_price` bigint(5) ,
+        `child_price` bigint(5) ,
+        `countries` bigint(5) ,
+        `attractions` bigint(5) ,
+        `total_ticket_price` bigint(5) ,
+        PRIMARY KEY (`outbound_route_id`)
+    )
+    '''
+    cursor.execute(create_table_query)
+    insert_data_query = '''
+    INSERT INTO outbound_travel_routes (imp_date, outbound_route_id,outbound_route_name,travel_agency_id,
+    outbound_departure_city,outbound_days,adult_price,child_price,countries,attractions,total_ticket_price)
+    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
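+    -- 11 placeholders, one per column in the INSERT list above;
+    -- executemany below binds the five fixture rows in one pass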
+    '''
+    data = [
+        (imp_date, "item_enterprise_7_61", "德法瑞意深度纵览一价无忧", "item_enterprise_7_59", "北京", 12, 10900, 8900, 5, 15, 750),
+        (imp_date, "item_enterprise_7_62", "意大利全景+西西里精华深度纵览", "item_enterprise_7_59", "天津", 20, 18900, 15900, 10, 25, 2500),
+        (imp_date, "item_enterprise_7_63", "悦选意大利经典大城小镇书香之旅", "item_enterprise_7_57", "上海", 20, 10900, 8900, 5, 15, 2500),
+        (imp_date, "item_enterprise_7_64", "新西兰南岛双冰川双峡湾深度纯净之旅", "item_enterprise_7_59", "哈尔滨", 12, 18900, 8900, 5, 15, 2500),
+        (imp_date, "item_enterprise_7_65", "英国+爱尔兰+威尔士精选之旅", "item_enterprise_7_57", "深圳", 12, 18900, 15900, 10, 15, 750)
+    ]
+    cursor.executemany(insert_data_query, data)
+    conn.commit()
+
+    create_table_query = '''
+    CREATE TABLE IF NOT EXISTS `country_outbound_travel` (
+        `imp_date` varchar(50) ,
+        `outbound_travel_route_id` varchar(50) NOT NULL ,
+        `nation` varchar(20) NOT NULL ,
+        `travel_days` int(6) NOT NULL ,
+        `outbound_number_attractions` int(6) NOT NULL
+    )
+    '''
+    cursor.execute(create_table_query)
+    insert_data_query = '''
+    INSERT INTO country_outbound_travel (imp_date, outbound_travel_route_id,nation,travel_days,outbound_number_attractions)
+    VALUES (?, ?, ?, ?, ?)
+    '''
+    data = [
+        (imp_date, "item_enterprise_7_64", "意大利", 2, 3),
+        (imp_date, "item_enterprise_7_63", "法国", 4, 5),
+        (imp_date, "item_enterprise_7_62", "德国", 4, 5),
+        (imp_date, "item_enterprise_7_65", "瑞士", 4, 3),
+        (imp_date, "item_enterprise_7_61", "英格兰", 4, 3)
+    ]
+    cursor.executemany(insert_data_query, data)
+    conn.commit()
+
+    create_table_query = '''
+    CREATE TABLE IF NOT EXISTS `domestic_travel_routes` (
+        `imp_date` varchar(50) ,
+        `domestic_travel_route_id` varchar(50) NOT NULL ,
+        `domestic_travel_route_name` varchar(50) NOT NULL ,
+        `travel_agency_id` varchar(50) NOT NULL ,
+        `domestic_departure_city` varchar(50) NOT NULL ,
+        `domestic_days` int(5) ,
+        `presale_price` int(8) NOT NULL ,
+        `tour_price` int(8) NOT NULL ,
+        `number_people_group` int(8) NOT NULL ,
+        `personal_price` int(8) NOT NULL ,
+        `domestic_number_attractions` int(6) NOT NULL ,
+        PRIMARY KEY (`domestic_travel_route_id`)
+    )
+    '''
+    cursor.execute(create_table_query)
+    insert_data_query = '''
+    INSERT INTO domestic_travel_routes (imp_date, domestic_travel_route_id,domestic_travel_route_name,travel_agency_id,domestic_departure_city,domestic_days,presale_price,tour_price,number_people_group,personal_price,domestic_number_attractions)
+    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+    '''
+    data = [
+        (imp_date, "item_enterprise_7_66", "桂林深度跟团游", "item_enterprise_7_60", "北京", 4, 2500, 2000, 2, 3000, 10),
+        (imp_date, "item_enterprise_7_67", "厦门休闲游", "item_enterprise_7_56", "天津", 8, 6500, 5000, 5, 7000, 20),
+        (imp_date, "item_enterprise_7_68", "重庆红色之旅", "item_enterprise_7_60", "上海", 4, 6500, 2000, 5, 3000, 20),
+        (imp_date, "item_enterprise_7_69", "云南古城游", "item_enterprise_7_59", "哈尔滨", 4, 6500, 2000, 5, 7000, 20),
+        (imp_date, "item_enterprise_7_70", "上海时尚游", "item_enterprise_7_59", "深圳", 4, 6500, 2000, 5, 7000, 10)
+    ]
+    cursor.executemany(insert_data_query, data)
+    conn.commit()
+
+    create_table_query = '''
+    CREATE TABLE IF NOT EXISTS `cruise_route` (
+        `imp_date` varchar(50) ,
+        `cruise_route_id` varchar(50) NOT NULL ,
+        `cruise_route_name` varchar(50) NOT NULL ,
+        `travel_agency_id` varchar(50) NOT NULL ,
+        `cruise_departure_city` varchar(50) NOT NULL ,
+        `cruise_days` int(5) ,
+        `interior_cabin_price` int(8) NOT NULL ,
+        `sea_view_room_price` int(8) NOT NULL ,
+        `balcony_room_price` int(8) NOT NULL ,
+        `sailing_area` varchar(50) NOT NULL ,
+        `cruise_line` varchar(50) NOT NULL ,
+        PRIMARY KEY (`cruise_route_id`)
+    )
+    '''
+    cursor.execute(create_table_query)
+    insert_data_query = '''
+    INSERT INTO cruise_route (imp_date, cruise_route_id,cruise_route_name,travel_agency_id,cruise_departure_city,cruise_days,interior_cabin_price,sea_view_room_price,balcony_room_price,sailing_area,cruise_line)
+    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+    '''
+    data = [
+        (imp_date, "item_enterprise_7_71", "南极摄影旅游", "item_enterprise_7_57", "大连", 6, 4399, 4799, 5299, "日本航线", "皇家加勒比国际游轮"),
+        (imp_date, "item_enterprise_7_72", "地中海巡游", "item_enterprise_7_58", "天津", 10, 6399, 6799, 7399, "韩国航线", "海洋亚特兰蒂游轮"),
+        (imp_date, "item_enterprise_7_73", "超凡体验来自未来的游轮", "item_enterprise_7_60", "上海", 10, 4399, 4799, 5299, "南极航线", "庞洛游轮"),
+        (imp_date, "item_enterprise_7_74", "超凡体验来自未来的游轮", "item_enterprise_7_60", "深圳", 10, 6399, 6799, 5299, "南极航线", "庞洛游轮"),
+        (imp_date, "item_enterprise_7_75", "超凡体验来自未来的游轮", "item_enterprise_7_60", "天津", 10, 6399, 4799, 5299, "韩国航线", "海洋亚特兰蒂游轮")
+    ]
+    cursor.executemany(insert_data_query, data)
+    conn.commit()
+    conn.close()
+
+
+def build_table():
+    current_directory = os.path.dirname(os.path.abspath(__file__))
+    config_file = current_directory + "/config/config.yaml"
+    with open(config_file, 'r') as file:
+        config = yaml.safe_load(file)
+    db_path = current_directory + "/data/"
+    db_file = db_path + config["domain"] + ".db"
+    if os.path.exists(db_file):
+        os.remove(db_file)
+        print("db_file removed!")
+    print(db_file)
+    # build the fixture that matches the configured domain
+    if config["domain"] == "china_travel_agency":
+        build_china_travel_agency(db_path, -1)
+    else:
+        build_internet(db_path, -1)
+
+
+if __name__ == '__main__':
+    build_table()
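A quick way to sanity-check the generated fixture is to count rows per table. This is a minimal sketch, assuming config.yaml keeps domain: internet (so the file is data/internet.db) and that it is run from the evaluation directory:

    import sqlite3

    conn = sqlite3.connect("data/internet.db")
    tables = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table'").fetchall()
    for (name,) in tables:
        count = conn.execute(f"SELECT count(*) FROM {name}").fetchone()[0]
        print(name, count)  # each fixture table above is seeded with 5 rows
    conn.close()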
diff --git a/evaluation/config/__init__.py b/evaluation/config/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/evaluation/config/config.yaml b/evaluation/config/config.yaml
new file mode 100644
index 000000000..6aa2972a8
--- /dev/null
+++ b/evaluation/config/config.yaml
@@ -0,0 +1,4 @@
+chat_id: 3
+agent_id: 4
+domain: internet
+url: http://localhost:9080
diff --git a/evaluation/data/gold_example_dusql.txt b/evaluation/data/gold_example_dusql.txt
new file mode 100644
index 000000000..d5fc41160
--- /dev/null
+++ b/evaluation/data/gold_example_dusql.txt
@@ -0,0 +1,100 @@
+SELECT T3.company_name, T3.annual_turnover, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id internet
+SELECT T3.company_name, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id internet
+SELECT T3.company_name, T2.brand_name, T2.legal_representative, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id internet
+SELECT
T3.company_name, T3.headquarter_address, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id internet +SELECT T3.company_name, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id WHERE T1.profit_proportion <= 0.1 internet +SELECT T3.company_name, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id WHERE T1.profit_proportion < 0.1 internet +SELECT T3.company_name, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id WHERE T1.profit_proportion > 0.1 internet +SELECT T2.brand_name, T2.legal_representative FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital >= 100000000 GROUP BY T1.brand_id ORDER BY avg(T1.revenue_proportion) DESC LIMIT 1 internet +SELECT T2.brand_name, T2.legal_representative FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital > 100000000 GROUP BY T1.brand_id ORDER BY count(*) ASC LIMIT 5 internet +SELECT T2.brand_name, avg(T1.revenue_proportion), T2.legal_representative FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital < 100000000 GROUP BY T1.brand_id internet +SELECT T2.brand_name, avg(T1.revenue_proportion), T2.legal_representative FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital >= 100000000 GROUP BY T1.brand_id internet +SELECT T2.brand_name, sum(T1.revenue_proportion), T2.legal_representative FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital <= 100000000 GROUP BY T1.brand_id internet +SELECT T2.brand_name, max(T1.revenue_proportion), T2.legal_representative FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital < 100000000 GROUP BY T1.brand_id internet +SELECT T2.brand_name, max(T1.revenue_proportion), T2.legal_representative FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital <= 100000000 GROUP BY T1.brand_id internet +SELECT T2.brand_name, T2.legal_representative FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital > 100000000 GROUP BY T1.brand_id HAVING avg(T1.revenue_proportion) = 0.5 internet +SELECT T2.brand_name, T2.legal_representative FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital <= 100000000 GROUP BY T1.brand_id HAVING count(*) = 5 internet +SELECT T2.brand_name, avg(T1.revenue_proportion) FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital <= 100000000 GROUP BY T1.brand_id HAVING avg(T1.expenditure_proportion) <= 0.45 internet +SELECT T2.brand_name, min(T1.revenue_proportion) FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital < 100000000 GROUP BY T1.brand_id HAVING count(*) <= 5 internet +SELECT T2.legal_representative, T2.brand_name, avg(T1.revenue_proportion) FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id internet +SELECT 
T2.legal_representative, T2.brand_name, min(T1.revenue_proportion) FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id internet +SELECT T2.legal_representative, T2.brand_name, sum(T1.revenue_proportion) FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id internet +SELECT T2.legal_representative, T2.brand_name, max(T1.revenue_proportion) FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id internet +SELECT T2.legal_representative, T2.brand_name FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id HAVING count(*) <= 5 internet +SELECT T2.legal_representative, T2.brand_name FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id HAVING avg(T1.revenue_proportion) > 0.5 internet +SELECT T2.brand_name, avg(T1.revenue_proportion) FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id HAVING avg(T1.profit_proportion) <= 0.6 internet +SELECT T2.brand_name, min(T1.revenue_proportion) FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id HAVING count(*) = 5 internet +SELECT T2.brand_name, T2.legal_representative, avg(T1.revenue_proportion) FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id ORDER BY avg(T1.profit_proportion) DESC LIMIT 1 internet +SELECT T2.brand_name, T2.legal_representative, sum(T1.revenue_proportion) FROM company_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id ORDER BY count(*) DESC LIMIT 3 internet +SELECT T3.company_name, T2.brand_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id ORDER BY T1.profit_proportion ASC internet +SELECT T2.brand_name, T3.company_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id ORDER BY T1.expenditure_proportion ASC LIMIT 3 internet +SELECT T2.brand_name, T3.company_name, T1.revenue_proportion, T1.profit_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id ORDER BY T1.expenditure_proportion DESC LIMIT 3 internet +SELECT T2.brand_name, T3.company_name, T1.revenue_proportion FROM company_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T1.company_id = T3.company_id ORDER BY T1.profit_proportion DESC LIMIT 3 internet +SELECT brand_name FROM brand WHERE legal_representative NOT IN (SELECT legal_representative FROM brand GROUP BY legal_representative HAVING avg(registered_capital) < 1000000) internet +SELECT T1.brand_name, T2.company_name, T1.legal_representative, T2.annual_turnover FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id internet +SELECT T1.brand_name, T2.company_name, T1.registered_capital FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id internet +SELECT T1.brand_name, T2.company_name, T1.registered_capital, T2.annual_turnover FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id internet +SELECT T1.brand_name, T2.company_name, T1.registered_capital, T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id internet +SELECT T1.brand_name, T2.company_name, T1.legal_representative, 
T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id internet +SELECT T2.company_name, T2.headquarter_address FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id internet +SELECT T2.company_name, T2.annual_turnover FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id internet +SELECT T3.company_name, T3.headquarter_address, T2.brand_name, T1.revenue FROM company_brand_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T2.company_id = T3.company_id internet +SELECT T3.company_name, T3.annual_turnover, T2.brand_name, T1.revenue FROM company_brand_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T2.company_id = T3.company_id internet +SELECT T3.company_name, T2.brand_name, T2.legal_representative, T1.revenue FROM company_brand_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T2.company_id = T3.company_id internet +SELECT T3.company_name, T2.brand_name, T1.revenue FROM company_brand_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T2.company_id = T3.company_id internet +SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital >= 100000000 internet +SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital < 100000000 internet +SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital > 100000000 internet +SELECT T2.company_name, T2.annual_turnover, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital <= 100000000 internet +SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital > 100000000 AND T2.annual_turnover <= 28800000000 internet +SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital > 100000000 AND T2.annual_turnover < 28800000000 internet +SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital < 100000000 AND T2.annual_turnover < 28800000000 internet +SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital <= 100000000 AND T2.annual_turnover >= 28800000000 internet +SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital >= 100000000 AND T2.annual_turnover > 28800000000 internet +SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital <= 100000000 AND T2.annual_turnover > 28800000000 internet +SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital > 100000000 AND T2.annual_turnover > 
28800000000 internet +SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital <= 100000000 AND T2.annual_turnover < 28800000000 internet +SELECT T2.company_name, T2.headquarter_address, T1.legal_representative FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T1.registered_capital < 100000000 AND T2.annual_turnover <= 28800000000 internet +SELECT T3.company_name, T2.brand_name, T1.revenue FROM company_brand_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T2.company_id = T3.company_id WHERE T1.profit <= 50000000000 internet +SELECT T3.company_name, T2.brand_name, T1.revenue FROM company_brand_revenue AS T1 JOIN brand AS T2 JOIN company AS T3 ON T1.brand_id = T2.brand_id AND T2.company_id = T3.company_id WHERE T1.revenue_growth_year_on_year >= 1 internet +SELECT T2.company_name, T2.headquarter_address FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover >= 28800000000 GROUP BY T1.company_id ORDER BY avg(T1.revenue_proportion) ASC LIMIT 5 internet +SELECT T2.company_name, T2.headquarter_address FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover >= 28800000000 GROUP BY T1.company_id ORDER BY avg(T1.revenue_proportion) DESC LIMIT 1 internet +SELECT T2.brand_name, T2.legal_representative FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital > 100000000 GROUP BY T1.brand_id ORDER BY avg(T1.revenue) DESC LIMIT 1 internet +SELECT T2.brand_name, T2.legal_representative FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital <= 100000000 GROUP BY T1.brand_id ORDER BY avg(T1.revenue) DESC LIMIT 1 internet +SELECT T2.company_name, T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover >= 28800000000 GROUP BY T1.company_id ORDER BY avg(T1.registered_capital) DESC LIMIT 5 internet +SELECT T2.company_name, T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover >= 28800000000 GROUP BY T1.company_id ORDER BY sum(T1.registered_capital) ASC LIMIT 5 internet +SELECT T2.company_name, max(T1.revenue_proportion), T2.headquarter_address FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover >= 28800000000 GROUP BY T1.company_id internet +SELECT T2.company_name, sum(T1.revenue_proportion), T2.headquarter_address FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover < 28800000000 GROUP BY T1.company_id internet +SELECT T2.company_name, max(T1.revenue_proportion), T2.headquarter_address FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover > 28800000000 GROUP BY T1.company_id internet +SELECT T2.company_name, sum(T1.revenue_proportion), T2.headquarter_address FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover >= 28800000000 GROUP BY T1.company_id internet +SELECT T2.company_name, avg(T1.revenue_proportion), T2.headquarter_address FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover <= 28800000000 GROUP BY T1.company_id internet +SELECT T2.brand_name, avg(T1.revenue), 
T2.legal_representative FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital <= 100000000 GROUP BY T1.brand_id internet +SELECT T2.brand_name, avg(T1.revenue), T2.legal_representative FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital >= 100000000 GROUP BY T1.brand_id internet +SELECT T2.brand_name, min(T1.revenue), T2.legal_representative FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital < 100000000 GROUP BY T1.brand_id internet +SELECT T2.brand_name, max(T1.revenue), T2.legal_representative FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital <= 100000000 GROUP BY T1.brand_id internet +SELECT T2.brand_name, avg(T1.revenue), T2.legal_representative FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital < 100000000 GROUP BY T1.brand_id internet +SELECT T2.company_name, min(T1.registered_capital), T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover > 28800000000 GROUP BY T1.company_id internet +SELECT T2.company_name, min(T1.registered_capital), T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover <= 28800000000 GROUP BY T1.company_id internet +SELECT T2.company_name, avg(T1.registered_capital), T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover <= 28800000000 GROUP BY T1.company_id internet +SELECT T2.company_name, max(T1.registered_capital), T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover <= 28800000000 GROUP BY T1.company_id internet +SELECT T2.company_name, sum(T1.registered_capital), T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover > 28800000000 GROUP BY T1.company_id internet +SELECT T2.company_name, T2.headquarter_address FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover <= 28800000000 GROUP BY T1.company_id HAVING sum(T1.revenue_proportion) <= 0.5 internet +SELECT T2.company_name, T2.headquarter_address FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover < 28800000000 GROUP BY T1.company_id HAVING sum(T1.revenue_proportion) <= 0.5 internet +SELECT T2.company_name, T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover >= 28800000000 GROUP BY T1.company_id HAVING count(*) <= 5 internet +SELECT T2.company_name, T2.headquarter_address FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover <= 28800000000 GROUP BY T1.company_id HAVING count(*) < 5 internet +SELECT T2.company_name, min(T1.revenue_proportion) FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover > 28800000000 GROUP BY T1.company_id HAVING count(*) > 5 internet +SELECT T2.company_name, max(T1.revenue_proportion) FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover < 28800000000 GROUP BY T1.company_id HAVING avg(T1.profit_proportion) > 0.75 internet +SELECT T2.brand_name, max(T1.revenue) FROM company_brand_revenue AS T1 JOIN 
brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital <= 100000000 GROUP BY T1.brand_id HAVING sum(T1.profit_growth_year_on_year) <= 1000000 internet +SELECT T2.brand_name, min(T1.revenue) FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id WHERE T2.registered_capital < 100000000 GROUP BY T1.brand_id HAVING count(*) < 5 internet +SELECT T2.company_name, sum(T1.registered_capital) FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover > 28800000000 GROUP BY T1.company_id HAVING count(*) < 5 internet +SELECT T2.company_name, sum(T1.registered_capital) FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id WHERE T2.annual_turnover < 28800000000 GROUP BY T1.company_id HAVING count(*) >= 5 internet +SELECT T2.headquarter_address, T2.company_name, max(T1.revenue_proportion) FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id GROUP BY T1.company_id internet +SELECT T2.headquarter_address, T2.company_name, min(T1.revenue_proportion) FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id GROUP BY T1.company_id internet +SELECT T2.headquarter_address, T2.company_name, sum(T1.revenue_proportion) FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id GROUP BY T1.company_id internet +SELECT T2.headquarter_address, T2.company_name, avg(T1.revenue_proportion) FROM company_revenue AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id GROUP BY T1.company_id internet +SELECT T2.legal_representative, T2.brand_name, avg(T1.revenue) FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id internet +SELECT T2.legal_representative, T2.brand_name, min(T1.revenue) FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id internet +SELECT T2.legal_representative, T2.brand_name, sum(T1.revenue) FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id internet +SELECT T2.legal_representative, T2.brand_name, max(T1.revenue) FROM company_brand_revenue AS T1 JOIN brand AS T2 ON T1.brand_id = T2.brand_id GROUP BY T1.brand_id internet +SELECT T2.headquarter_address, T2.company_name, sum(T1.registered_capital) FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id GROUP BY T1.company_id internet +SELECT T2.headquarter_address, T2.company_name, max(T1.registered_capital) FROM brand AS T1 JOIN company AS T2 ON T1.company_id = T2.company_id GROUP BY T1.company_id internet diff --git a/evaluation/data/internet.txt b/evaluation/data/internet.txt new file mode 100644 index 000000000..2c7eb4ded --- /dev/null +++ b/evaluation/data/internet.txt @@ -0,0 +1,100 @@ +在各公司所有品牌收入排名中,给出每一个品牌,其所在公司以及收入占该公司的总收入比例,同时给出该公司的年营业额 +在各公司所有品牌收入排名中,给出每一个品牌,其所在公司以及收入占该公司的总收入比例 +在各公司所有品牌收入排名中,给出每一个品牌和其法人,其所在公司以及收入占该公司的总收入比例 +在各公司所有品牌收入排名中,给出每一个品牌,其所在公司以及收入占该公司的总收入比例,同时给出该公司总部所在地 +在公司各品牌收入排名的利润占比最多10%时,给出公司的名称品牌的名称并给出公司各品牌收入排名的营收占比 +在公司各品牌收入排名的利润占比小于10%时,给出公司的名称品牌的名称并给出公司各品牌收入排名的营收占比 +在公司各品牌收入排名的利润占比不止10%时,给出公司的名称品牌的名称并给出公司各品牌收入排名的营收占比 +注册资本不小于1亿的品牌中,哪个品牌的平均营收占比最大?并给出它的法定代表人 +注册资本大于1亿的品牌中,哪5个品牌收入最少?并给出它们的法定代表人 +找到注册资本少于一个亿的品牌及其法人,并给出对应公司的平均营收占比 +给出注册资本不少于一亿的品牌及其法人,并给出对应的公司的平均营收占比 +给出注册资本不超过一亿的的品牌及其法人,并给出对应的公司的总营收占比 +给出注册资本少于一亿的品牌及其法人,并给出对应的公司的最大营收占比 +给出注册资本不超过1亿的品牌及其法人,并找出对应的公司的最大营收占比 +在注册资本超过一亿的公司中,给出公司各品牌收入品牌的平均营收占比正好50%的品牌及其法人 +在注册资本不超过一亿的公司中,给出个品牌收入排名五的品牌及其法人 
+在各品牌收入在公司排名中,当品牌的注册资本不大于1亿时,给出公司各品牌收入排名的支出占比的平均值小于等于45%的那些品牌的名称以及公司各品牌收入排名的营收占比的平均值 +在各品牌收入在公司排名中,当品牌的注册资本小于1亿时,给出公司各品牌收入排名数量小于等于5的那些品牌的名称以及公司各品牌收入排名的营收占比的最小值 +在各品牌收入在公司排名中,给出每个品牌的名称,品牌的法定代表人,以及公司各品牌收入排名的营收占比的平均值 +在各品牌收入在公司排名中,给出每个品牌的名称,品牌的法定代表人,以及公司各品牌收入排名的营收占比的最小值 +在各品牌收入在公司排名中,给出每个品牌的名称,品牌的法定代表人,以及公司各品牌收入排名的营收占比的总和 +在各品牌收入在公司排名中,给出每个品牌的名称,品牌的法定代表人,以及公司各品牌收入排名的营收占比的最大值 +在各品牌收入在公司排名中,给出收入排名不超过5的品牌及其法人 +在各品牌收入在公司排名中,给出在收入排名中的平均营收占比超过50%的品牌及其法人 +在各品牌收入在公司排名中,当公司各品牌收入排名的利润占比的平均值小于等于60%时,给出品牌的名称以及公司各品牌收入排名的营收占比的平均值 +在各品牌收入在公司排名中,当公司各品牌收入排名数量等于5时,给出品牌的名称以及公司各品牌收入排名的营收占比的最小值 +哪个品牌收入的平均利润占比最大,给出品牌的法定代表人,以及其收入平均营收占比 +哪3个品牌的收入最多,给出品牌的法定代表人,以及其收入总营收占比 +在公司各品牌收入排名的利润占比由少到多排列,给出对应的品牌的名称公司的名称以及公司各品牌收入排名的营收占比 +在公司各品牌收入排名的支出占比最少时,给出排名前3对应的品牌的名称公司的名称以及公司各品牌收入排名的营收占比 +在公司各品牌收入排名的支出占比最多时,给出排名前3对应的品牌的名称公司的名称以及公司各品牌收入排名的营收占比公司各品牌收入排名的利润占比 +在公司各品牌收入排名的利润占比最多时,给出排名前3对应的品牌的名称公司的名称以及公司各品牌收入排名的营收占比 +哪些法人的品牌的不在注册资本少于100万中,这些法定代表人的品牌都是哪些? +给出每一个品牌和其法人,所属的公司以及该公司的年营业额 +给出每一个品牌和注册时所用资本,以及所属的公司 +给出每一个品牌和注册时所用资本,所属的公司以及该公司的年营业额 +给出每一个品牌和注册时所用资本,所属的公司和总部所在地 +给出每一个品牌和其法人,所属的公司以及总部所在城市 +有自己品牌的公司有哪些?给出这些公司和总部所在地 +有自己品牌的公司有哪些?给出这些公司和年营业额 +在各公司其品牌的历年收入中,给出每一个品牌,其所属的公司和公司总部所在地点,并给出该品牌近几年的营收 +在各公司其品牌的历年收入中,给出每一个品牌,其所属的公司和公司年营业额,并给出该品牌近几年的营收 +在各公司其品牌的历年收入中,给出每一个品牌,其所属的公司和公司法人,并给出该品牌近几年的营收 +在各公司其品牌的历年收入中,给出每一个品牌,其所属的公司,以及该品牌近几年的营收 +在品牌的注册资本至少1亿时,给出公司的名称以及公司的总部地点品牌的法定代表人 +在品牌的注册资本少于1亿时,给出公司的名称以及公司的总部地点品牌的法定代表人 +在品牌的注册资本超过1亿时,给出公司的名称以及公司的总部地点品牌的法定代表人 +在品牌的注册资本最多1亿时,给出公司的名称以及公司的年营业额品牌的法定代表人 +找到注册资本不止1亿,且年营业额不超过288亿的公司,以及给出总部地点,法人 +找出注册资本不止一亿,且年营业额低于288亿的公司,以及总部地点和法人 +找出注册资本不到1亿,且年营业额少于288亿的公司,以及给出总部地点和法人 +给出注册资本不超过1亿且年营业额不少于288亿的公司,总部地点和法人 +给出注册资本不低于一亿,且年营业额不止288亿的公司,以及总部地点和法人 +给出注册资本不超过一亿,且年营业额超过288亿的公司,总部在哪,法人是谁 +给出注册资本超过一亿且年营业额不止288亿的公司,以及总部在哪,法人是谁 +给出注册资本不超过一亿,且年营业额少于288亿的公司,以及总部地点和法人 +找出注册资本少于一亿,且年营业额不超过288亿的公司,以及总部在哪,法人是谁 +在公司品牌历年收入的利润最多500亿时,给出公司的名称品牌的名称并给出公司品牌历年收入的营收 +在公司品牌历年收入的营收同比增长至少100%时,给出公司的名称品牌的名称并给出公司品牌历年收入的营收 +年营业额不小于288亿的公司中,哪5个公司的平均营收占比最少?,并给出它们的总部地点 +年营业额不小于288亿的公司中,哪个公司的平均营收占比最大?并给出它的总部地点 +注册资本大于1亿的品牌中,哪个品牌历年收入的平均营收最大?并给出它的法定代表人 +注册资本不大于1亿的品牌中,哪个品牌历年收入的平均营收最大?并给出它的法定代表人 +年营业额不小于288亿的公司中,哪5个公司品牌的平均注册资本最多?并给出它们的总部地点 +年营业额不小于288亿的公司中,哪5个公司品牌的平均注册资本总共最少?并给出它们的总部地点 +找到年营业额不少于288亿的公司及总部地点,并给出对应的公司的最高营收占比 +给出年营业额少于288亿的公司及总部地点,并给出对应的公司的总营收占比 +给出年营业额不止288亿的公司及总部地点,并给出对应的公司的最大营收占比 +找出年营业额不少于288亿的公司及总部地点,并给出对应的公司的总营收占比 +请找出年营业额不超过288亿的公司及总部地点,并给出对应的公司的平均营收占比 +找到品牌的注册资本不大于1亿品牌的法定代表人并给出公司品牌历年收入的营收的平均值 +找到品牌的注册资本不小于1亿品牌的法定代表人并给出公司品牌历年收入的营收的平均值 +找到品牌的注册资本小于1亿品牌的法定代表人并给出公司品牌历年收入的营收的最小值 +找到品牌的注册资本不大于1亿品牌的法定代表人并给出公司品牌历年收入的营收的最大值 +找到品牌的注册资本小于1亿品牌的法定代表人并给出公司品牌历年收入的营收的平均值 +给出年营业额超过288亿的公司及总部地点,并给出对应的品牌中的最小注册资本 +给出年营业额不超过288亿的公司及总部地点,并给出对应的品牌中的最小注册资本 +给出不超过288亿年营业额的公司及总部地点,并给出这些品牌中的平均注册资本 +给出不超过288亿年营业额的公司及其总部地点,并给出这些品牌的的最大注册资本 +给出年营业额超过288亿的公司及其总部地点,并给出这些品牌的的总注册资本 +给出年营业额不超过288亿的各公司品牌中,给出收入排名中的营收占比加起来不超过50%的公司及其总部地点 +给出年营业额低于288亿的各公司的各品牌中,给出收入排名中的总营收占比不超过50%的公司及总部地点中 +在年营业额不少于288亿的公司中,给出品牌不超过5个的公司及其总部地点 +在年营业额不超过288亿的公司中,给出品牌少于5个的公司及其总部地点 +在各公司的各品牌收入排名中,当公司的年营业额大于288亿时,给出公司各品牌收入排名数量大于5的那些公司的名称以及公司各品牌收入排名的营收占比的最小值 +在各公司的各品牌收入排名中,当公司的年营业额小于288亿时,给出公司各品牌收入排名的利润占比的平均值大于75%的那些公司的名称以及公司各品牌收入排名的营收占比的最大值 +在各品牌的历年收入中,当品牌的注册资本不大于1亿时,给出公司品牌历年收入的利润同比增长的总和小于等于1000000的那些品牌的名称以及公司品牌历年收入的营收的最大值 +在各品牌的历年收入中,当品牌的注册资本小于1亿时,给出公司品牌历年收入数量小于5的那些品牌的名称以及公司品牌历年收入的营收的最小值 +在各品牌所属的公司中,当公司的年营业额大于288亿时,给出品牌数量小于5的那些公司的名称以及品牌的注册资本的总和 +在各品牌所属的公司中,当公司的年营业额小于288亿时,给出品牌数量大于等于5的那些公司的名称以及品牌的注册资本的总和 +在各公司每个品牌的收入排名中,给出每个公司,其总部地点,以及各品牌的最大营收占比 +在各公司每个品牌的收入排名中,给出每个公司,其总部地点,以及各品牌的最小营收占比 +在各公司每个品牌的收入排名中,给出每个公司,其总部地点,以及各品牌的总营收占比 +在各公司每个品牌的收入排名中,给出每个公司,其总部地点,以及各品牌的平均营收占比 
+在各品牌的历年收入中,给出每个品牌的名称,品牌的法定代表人,以及公司品牌历年收入的营收的平均值 +在各品牌的历年收入中,给出每个品牌的名称,品牌的法定代表人,以及公司品牌历年收入的营收的最小值 +在各品牌的历年收入中,给出每个品牌的名称,品牌的法定代表人,以及公司品牌历年收入的营收的总和 +在各品牌的历年收入中,给出每个品牌的名称,品牌的法定代表人,以及公司品牌历年收入的营收的最大值 +在各品牌所属的公司中,给出每个公司的名称,公司的总部地点,以及品牌的注册资本的总和 +在各品牌所属的公司中,给出每个公司的名称,公司的总部地点,以及品牌的注册资本的最大值 diff --git a/evaluation/data/tables_dusql.json b/evaluation/data/tables_dusql.json new file mode 100644 index 000000000..d9640f0b3 --- /dev/null +++ b/evaluation/data/tables_dusql.json @@ -0,0 +1,758 @@ +[ + { + "db_id": "internet", + "table_names": [ + "company", + "brand", + "company_brand_revenue", + "company_revenue" + ], + "column_names": [ + [ + -1, + "*" + ], + [ + 0, + "company_id" + ], + [ + 0, + "company_name" + ], + [ + 0, + "headquarter_address" + ], + [ + 0, + "company_established_time" + ], + [ + 0, + "founder" + ], + [ + 0, + "ceo" + ], + [ + 0, + "annual_turnover" + ], + [ + 0, + "employee_count" + ], + [ + 1, + "brand_id" + ], + [ + 1, + "brand_name" + ], + [ + 1, + "brand_established_time" + ], + [ + 1, + "company_id" + ], + [ + 1, + "legal_representative" + ], + [ + 1, + "registered_capital" + ], + [ + 2, + "year_time" + ], + [ + 2, + "brand_id" + ], + [ + 2, + "revenue" + ], + [ + 2, + "profit" + ], + [ + 2, + "revenue_growth_year_on_year" + ], + [ + 2, + "profit_growth_year_on_year" + ], + [ + 3, + "company_id" + ], + [ + 3, + "brand_id" + ], + [ + 3, + "revenue_proportion" + ], + [ + 3, + "profit_proportion" + ], + [ + 3, + "expenditure_proportion" + ] + ], + "table_names_original": [ + "company", + "brand", + "company_brand_revenue", + "company_revenue" + ], + "column_names_original": [ + [ + -1, + "*" + ], + [ + 0, + "company_id" + ], + [ + 0, + "company_name" + ], + [ + 0, + "headquarter_address" + ], + [ + 0, + "company_established_time" + ], + [ + 0, + "founder" + ], + [ + 0, + "ceo" + ], + [ + 0, + "annual_turnover" + ], + [ + 0, + "employee_count" + ], + [ + 1, + "brand_id" + ], + [ + 1, + "brand_name" + ], + [ + 1, + "brand_established_time" + ], + [ + 1, + "company_id" + ], + [ + 1, + "legal_representative" + ], + [ + 1, + "registered_capital" + ], + [ + 2, + "year_time" + ], + [ + 2, + "brand_id" + ], + [ + 2, + "revenue" + ], + [ + 2, + "profit" + ], + [ + 2, + "revenue_growth_year_on_year" + ], + [ + 2, + "profit_growth_year_on_year" + ], + [ + 3, + "company_id" + ], + [ + 3, + "brand_id" + ], + [ + 3, + "revenue_proportion" + ], + [ + 3, + "profit_proportion" + ], + [ + 3, + "expenditure_proportion" + ] + ], + "column_types": [ + "text", + "number", + "text", + "text", + "time", + "text", + "time", + "number", + "number", + "number", + "text", + "time", + "text", + "text", + "number", + "time", + "number", + "number", + "number", + "number", + "number", + "number", + "number", + "number", + "number", + "number" + ], + "foreign_keys": [ + [ + 12, + 1 + ], + [ + 21, + 1 + ], + [ + 22, + 9 + ], + [ + 16, + 9 + ] + ], + "primary_keys": [ + 1, + 9 + ] + }, + { + "db_id": "china_travel_agency", + "table_names": [ + "travel_agency", + "outbound_travel_routes", + "country_outbound_travel", + "domestic_travel_routes", + "cruise_route" + ], + "column_names": [ + [ + -1, + "*" + ], + [ + 0, + "travel_agency_id" + ], + [ + 0, + "travel_agency_name" + ], + [ + 0, + "travel_agency_level" + ], + [ + 0, + "number_countrie_outbound_travel" + ], + [ + 0, + "number_domestic_tourist_cities" + ], + [ + 0, + "number_outbound_travel_routes" + ], + [ + 0, + "number_domestic_travel_routes" + ], + [ + 0, + "asia_ranking" + ], + [ + 0, + "number_overseas_tourists_received" + ], + [ + 0, + 
"number_overseas_companies" + ], + [ + 0, + "number_holding_subsidiaries" + ], + [ + 0, + "number_traveling_salesmen_business_relationships" + ], + [ + 0, + "number_duty_free_shops" + ], + [ + 1, + "outbound_route_id" + ], + [ + 1, + "outbound_route_name" + ], + [ + 1, + "travel_agency_id" + ], + [ + 1, + "outbound_departure_city" + ], + [ + 1, + "outbound_days" + ], + [ + 1, + "adult_price" + ], + [ + 1, + "child_price" + ], + [ + 1, + "countries" + ], + [ + 1, + "attractions" + ], + [ + 1, + "total_ticket_price" + ], + [ + 2, + "outbound_travel_route_id" + ], + [ + 2, + "nation" + ], + [ + 2, + "travel_days" + ], + [ + 2, + "outbound_number_attractions" + ], + [ + 3, + "domestic_travel_route_id" + ], + [ + 3, + "domestic_travel_route_name" + ], + [ + 3, + "travel_agency_id" + ], + [ + 3, + "domestic_departure_city" + ], + [ + 3, + "domestic_days" + ], + [ + 3, + "presale_price" + ], + [ + 3, + "tour_price" + ], + [ + 3, + "number_people_group" + ], + [ + 3, + "personal_price" + ], + [ + 3, + "domestic_number_attractions" + ], + [ + 4, + "cruise_route_id" + ], + [ + 4, + "cruise_route_name" + ], + [ + 4, + "travel_agency_id" + ], + [ + 4, + "cruise_departure_city" + ], + [ + 4, + "cruise_days" + ], + [ + 4, + "interior_cabin_price" + ], + [ + 4, + "sea_view_room_price" + ], + [ + 4, + "balcony_room_price" + ], + [ + 4, + "sailing_area" + ], + [ + 4, + "cruise_line" + ] + ], + "table_names_original": [ + "travel_agency", + "outbound_travel_routes", + "country_outbound_travel", + "domestic_travel_routes", + "cruise_route" + ], + "column_names_original": [ + [ + -1, + "*" + ], + [ + 0, + "travel_agency_id" + ], + [ + 0, + "travel_agency_name" + ], + [ + 0, + "travel_agency_level" + ], + [ + 0, + "number_countrie_outbound_travel" + ], + [ + 0, + "number_domestic_tourist_cities" + ], + [ + 0, + "number_outbound_travel_routes" + ], + [ + 0, + "number_domestic_travel_routes" + ], + [ + 0, + "asia_ranking" + ], + [ + 0, + "number_overseas_tourists_received" + ], + [ + 0, + "number_overseas_companies" + ], + [ + 0, + "number_holding_subsidiaries" + ], + [ + 0, + "number_traveling_salesmen_business_relationships" + ], + [ + 0, + "number_duty_free_shops" + ], + [ + 1, + "outbound_route_id" + ], + [ + 1, + "outbound_route_name" + ], + [ + 1, + "travel_agency_id" + ], + [ + 1, + "outbound_departure_city" + ], + [ + 1, + "outbound_days" + ], + [ + 1, + "adult_price" + ], + [ + 1, + "child_price" + ], + [ + 1, + "countries" + ], + [ + 1, + "attractions" + ], + [ + 1, + "total_ticket_price" + ], + [ + 2, + "outbound_travel_route_id" + ], + [ + 2, + "nation" + ], + [ + 2, + "travel_days" + ], + [ + 2, + "outbound_number_attractions" + ], + [ + 3, + "domestic_travel_route_id" + ], + [ + 3, + "domestic_travel_route_name" + ], + [ + 3, + "travel_agency_id" + ], + [ + 3, + "domestic_departure_city" + ], + [ + 3, + "domestic_days" + ], + [ + 3, + "presale_price" + ], + [ + 3, + "tour_price" + ], + [ + 3, + "number_people_group" + ], + [ + 3, + "personal_price" + ], + [ + 3, + "domestic_number_attractions" + ], + [ + 4, + "cruise_route_id" + ], + [ + 4, + "cruise_route_name" + ], + [ + 4, + "travel_agency_id" + ], + [ + 4, + "cruise_departure_city" + ], + [ + 4, + "cruise_days" + ], + [ + 4, + "interior_cabin_price" + ], + [ + 4, + "sea_view_room_price" + ], + [ + 4, + "balcony_room_price" + ], + [ + 4, + "sailing_area" + ], + [ + 4, + "cruise_line" + ] + ], + "column_types": [ + "text", + "number", + "text", + "text", + "number", + "number", + "number", + "number", + "number", + "number", + "number", + 
"number", + "number", + "number", + "number", + "text", + "number", + "text", + "number", + "number", + "number", + "number", + "number", + "number", + "number", + "text", + "number", + "number", + "number", + "text", + "number", + "text", + "number", + "number", + "number", + "number", + "number", + "number", + "number", + "text", + "number", + "text", + "number", + "number", + "number", + "number", + "text", + "text" + ], + "foreign_keys": [ + [ + 40, + 1 + ], + [ + 24, + 14 + ], + [ + 30, + 1 + ], + [ + 16, + 1 + ] + ], + "primary_keys": [ + 1, + 14, + 28, + 38 + ] + } +] \ No newline at end of file diff --git a/evaluation/evaluation.py b/evaluation/evaluation.py new file mode 100644 index 000000000..2fe956ffe --- /dev/null +++ b/evaluation/evaluation.py @@ -0,0 +1,922 @@ +################################ +# val: number(float)/string(str)/sql(dict) +# col_unit: (agg_id, col_id, isDistinct(bool)) +# val_unit: (unit_op, col_unit1, col_unit2) +# table_unit: (table_type, col_unit/sql) +# cond_unit: (not_op, op_id, val_unit, val1, val2) +# condition: [cond_unit1, 'and'/'or', cond_unit2, ...] +# sql { +# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) +# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} +# 'where': condition +# 'groupBy': [col_unit1, col_unit2, ...] +# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...]) +# 'having': condition +# 'limit': None/limit value +# 'intersect': None/sql +# 'except': None/sql +# 'union': None/sql +# } +################################ + +from __future__ import print_function + +import sqlparse +import logging +import os, sys +import json +import sqlite3 +import traceback +import argparse +import yaml +import re + +from process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql +from build_pred_result import read_query,get_pred_result +from build_tables import build_table + +# Flag to disable value evaluation +DISABLE_VALUE = True +# Flag to disable distinct in select evaluation +DISABLE_DISTINCT = True + + +CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') +JOIN_KEYWORDS = ('join', 'on', 'as') + +WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') +UNIT_OPS = ('none', '-', '+', "*", '/') +AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') +TABLE_TYPE = { + 'sql': "sql", + 'table_unit': "table_unit", +} + +COND_OPS = ('and', 'or') +SQL_OPS = ('intersect', 'union', 'except') +ORDER_OPS = ('desc', 'asc') + + +HARDNESS = { + "component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'), + "component2": ('except', 'union', 'intersect') +} + + +def condition_has_or(conds): + return 'or' in conds[1::2] + + +def condition_has_like(conds): + return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]] + + +def condition_has_sql(conds): + for cond_unit in conds[::2]: + val1, val2 = cond_unit[3], cond_unit[4] + if val1 is not None and type(val1) is dict: + return True + if val2 is not None and type(val2) is dict: + return True + return False + + +def val_has_op(val_unit): + return val_unit[0] != UNIT_OPS.index('none') + + +def has_agg(unit): + return unit[0] != AGG_OPS.index('none') + + +def accuracy(count, total): + if count == total: + return 1 + return 0 + + +def recall(count, total): + if count == total: + return 1 + return 0 + + +def F1(acc, rec): + if (acc + rec) == 0: + return 0 + return (2. 
* acc * rec) / (acc + rec) + + +def get_scores(count, pred_total, label_total): + if pred_total != label_total: + return 0,0,0 + elif count == pred_total: + return 1,1,1 + return 0,0,0 + + +def eval_sel(pred, label): + pred_sel = pred['select'][1] + label_sel = label['select'][1] + label_wo_agg = [unit[1] for unit in label_sel] + pred_total = len(pred_sel) + label_total = len(label_sel) + cnt = 0 + cnt_wo_agg = 0 + + for unit in pred_sel: + if unit in label_sel: + cnt += 1 + label_sel.remove(unit) + if unit[1] in label_wo_agg: + cnt_wo_agg += 1 + label_wo_agg.remove(unit[1]) + + return label_total, pred_total, cnt, cnt_wo_agg + + +def eval_where(pred, label): + pred_conds = [unit for unit in pred['where'][::2]] + label_conds = [unit for unit in label['where'][::2]] + label_wo_agg = [unit[2] for unit in label_conds] + pred_total = len(pred_conds) + label_total = len(label_conds) + cnt = 0 + cnt_wo_agg = 0 + + for unit in pred_conds: + if unit in label_conds: + cnt += 1 + label_conds.remove(unit) + if unit[2] in label_wo_agg: + cnt_wo_agg += 1 + label_wo_agg.remove(unit[2]) + + return label_total, pred_total, cnt, cnt_wo_agg + + +def eval_group(pred, label): + pred_cols = [unit[1] for unit in pred['groupBy']] + label_cols = [unit[1] for unit in label['groupBy']] + pred_total = len(pred_cols) + label_total = len(label_cols) + cnt = 0 + pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols] + label_cols = [label.split(".")[1] if "." in label else label for label in label_cols] + for col in pred_cols: + if col in label_cols: + cnt += 1 + label_cols.remove(col) + return label_total, pred_total, cnt + + +def eval_having(pred, label): + pred_total = label_total = cnt = 0 + if len(pred['groupBy']) > 0: + pred_total = 1 + if len(label['groupBy']) > 0: + label_total = 1 + + pred_cols = [unit[1] for unit in pred['groupBy']] + label_cols = [unit[1] for unit in label['groupBy']] + if pred_total == label_total == 1 \ + and pred_cols == label_cols \ + and pred['having'] == label['having']: + cnt = 1 + + return label_total, pred_total, cnt + + +def eval_order(pred, label): + pred_total = label_total = cnt = 0 + if len(pred['orderBy']) > 0: + pred_total = 1 + if len(label['orderBy']) > 0: + label_total = 1 + if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \ + ((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)): + cnt = 1 + return label_total, pred_total, cnt + + +def eval_and_or(pred, label): + pred_ao = pred['where'][1::2] + label_ao = label['where'][1::2] + pred_ao = set(pred_ao) + label_ao = set(label_ao) + + if pred_ao == label_ao: + return 1,1,1 + return len(pred_ao),len(label_ao),0 + + +def get_nestedSQL(sql): + nested = [] + for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]: + if type(cond_unit[3]) is dict: + nested.append(cond_unit[3]) + if type(cond_unit[4]) is dict: + nested.append(cond_unit[4]) + if sql['intersect'] is not None: + nested.append(sql['intersect']) + if sql['except'] is not None: + nested.append(sql['except']) + if sql['union'] is not None: + nested.append(sql['union']) + return nested + + +def eval_nested(pred, label): + label_total = 0 + pred_total = 0 + cnt = 0 + if pred is not None: + pred_total += 1 + if label is not None: + label_total += 1 + if pred is not None and label is not None: + cnt += Evaluator().eval_exact_match(pred, label) + return label_total, pred_total, cnt + + +def eval_IUEN(pred, label): + lt1, pt1, cnt1 = 
eval_nested(pred['intersect'], label['intersect']) + lt2, pt2, cnt2 = eval_nested(pred['except'], label['except']) + lt3, pt3, cnt3 = eval_nested(pred['union'], label['union']) + label_total = lt1 + lt2 + lt3 + pred_total = pt1 + pt2 + pt3 + cnt = cnt1 + cnt2 + cnt3 + return label_total, pred_total, cnt + + +def get_keywords(sql): + res = set() + if len(sql['where']) > 0: + res.add('where') + if len(sql['groupBy']) > 0: + res.add('group') + if len(sql['having']) > 0: + res.add('having') + if len(sql['orderBy']) > 0: + res.add(sql['orderBy'][0]) + res.add('order') + if sql['limit'] is not None: + res.add('limit') + if sql['except'] is not None: + res.add('except') + if sql['union'] is not None: + res.add('union') + if sql['intersect'] is not None: + res.add('intersect') + + # or keyword + ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2] + if len([token for token in ao if token == 'or']) > 0: + res.add('or') + + cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2] + # not keyword + if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0: + res.add('not') + + # in keyword + if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0: + res.add('in') + + # like keyword + if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0: + res.add('like') + + return res + + +def eval_keywords(pred, label): + pred_keywords = get_keywords(pred) + label_keywords = get_keywords(label) + pred_total = len(pred_keywords) + label_total = len(label_keywords) + cnt = 0 + + for k in pred_keywords: + if k in label_keywords: + cnt += 1 + return label_total, pred_total, cnt + + +def count_agg(units): + return len([unit for unit in units if has_agg(unit)]) + + +def count_component1(sql): + count = 0 + if len(sql['where']) > 0: + count += 1 + if len(sql['groupBy']) > 0: + count += 1 + if len(sql['orderBy']) > 0: + count += 1 + if sql['limit'] is not None: + count += 1 + if len(sql['from']['table_units']) > 0: # JOIN + count += len(sql['from']['table_units']) - 1 + + ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2] + count += len([token for token in ao if token == 'or']) + cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2] + count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) + + return count + + +def count_component2(sql): + nested = get_nestedSQL(sql) + return len(nested) + + +def count_others(sql): + count = 0 + # number of aggregation + agg_count = count_agg(sql['select'][1]) + agg_count += count_agg(sql['where'][::2]) + agg_count += count_agg(sql['groupBy']) + if len(sql['orderBy']) > 0: + agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] + + [unit[2] for unit in sql['orderBy'][1] if unit[2]]) + agg_count += count_agg(sql['having']) + if agg_count > 1: + count += 1 + + # number of select columns + if len(sql['select'][1]) > 1: + count += 1 + + # number of where conditions + if len(sql['where']) > 1: + count += 1 + + # number of group by clauses + if len(sql['groupBy']) > 1: + count += 1 + + return count + + +class Evaluator: + """A simple evaluator""" + def __init__(self): + self.partial_scores = None + + def eval_hardness(self, sql): + count_comp1_ = count_component1(sql) + count_comp2_ = count_component2(sql) + count_others_ = count_others(sql) + + if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0: + return "easy" + elif (count_others_ <= 2 
and count_comp1_ <= 1 and count_comp2_ == 0) or \ + (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0): + return "medium" + elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \ + (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \ + (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1): + return "hard" + else: + return "extra" + + def eval_exact_match(self, pred, label): + partial_scores = self.eval_partial_match(pred, label) + self.partial_scores = partial_scores + + for _, score in partial_scores.items(): + if score['f1'] != 1: + return 0 + if len(label['from']['table_units']) > 0: + label_tables = sorted(label['from']['table_units']) + pred_tables = sorted(pred['from']['table_units']) + return label_tables == pred_tables + return 1 + + def eval_partial_match(self, pred, label): + res = {} + + label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total) + res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total) + res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + label_total, pred_total, cnt = eval_group(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + label_total, pred_total, cnt = eval_having(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + label_total, pred_total, cnt = eval_order(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + label_total, pred_total, cnt = eval_and_or(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + label_total, pred_total, cnt = eval_IUEN(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + label_total, pred_total, cnt = eval_keywords(pred, label) + acc, rec, f1 = get_scores(cnt, pred_total, label_total) + res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total} + + return res + + +def isValidSQL(sql, db): + conn = sqlite3.connect(db) + cursor = conn.cursor() + try: + cursor.execute(sql) + except: + return False + return True + + +def print_scores(scores, etype): + levels = ['easy', 'medium', 'hard', 'extra', 'all'] + partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)', + 'group', 'order', 'and/or', 'IUEN', 'keywords'] + + print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels)) + counts = [scores[level]['count'] for level in levels] + print("{:20} 
{:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts)) + + if etype in ["all", "exec"]: + print('===================== EXECUTION ACCURACY =====================') + this_scores = [scores[level]['exec'] for level in levels] + print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores)) + + if etype in ["all", "match"]: + print('\n====================== EXACT MATCHING ACCURACY =====================') + exact_scores = [scores[level]['exact'] for level in levels] + print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores)) + print('\n---------------------PARTIAL MATCHING ACCURACY----------------------') + for type_ in partial_types: + this_scores = [scores[level]['partial'][type_]['acc'] for level in levels] + print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores)) + + print('---------------------- PARTIAL MATCHING RECALL ----------------------') + for type_ in partial_types: + this_scores = [scores[level]['partial'][type_]['rec'] for level in levels] + print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores)) + + print('---------------------- PARTIAL MATCHING F1 --------------------------') + for type_ in partial_types: + this_scores = [scores[level]['partial'][type_]['f1'] for level in levels] + print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores)) + + +def evaluate(gold, predict, db_dir, etype, kmaps,query_path): + with open(gold) as f: + glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0] + + with open(predict) as f: + plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0] + # plist = [("select max(Share),min(Share) from performance where Type != 'terminal'", "orchestra")] + # glist = [("SELECT max(SHARE) , min(SHARE) FROM performance WHERE TYPE != 'Live final'", "orchestra")] + evaluator = Evaluator() + #print(plist) + levels = ['easy', 'medium', 'hard', 'extra', 'all'] + partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)', + 'group', 'order', 'and/or', 'IUEN', 'keywords'] + entries = [] + scores = {} + log_list=[] + for level in levels: + scores[level] = {'count': 0, 'partial': {}, 'exact': 0.} + scores[level]['exec'] = 0 + for type_ in partial_types: + scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0} + + eval_err_num = 0 + questions=read_query(query_path) + index=0 + for p, g in zip(plist, glist): + p_str = p[0] + g_str, db = g + db_name = db + # db = os.path.join(db_dir, db, db + ".sqlite") + db = os.path.join(db_dir,db + ".db") + print(db) + schema = Schema(get_schema(db)) + g_sql = get_sql(schema, g_str) + hardness = evaluator.eval_hardness(g_sql) + scores[hardness]['count'] += 1 + scores['all']['count'] += 1 + + p_sql = p_str + + if etype in ["all", "exec"]: + exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql) + if not exec_score: + element={} + element["query"]=questions[index] + element["gold_sql"]=g_str + element["pred_sql"]=p_str + log_list.append(element) + if exec_score: + scores[hardness]['exec'] += 1.0 + scores['all']['exec'] += 1.0 + + if etype in ["all", "match"]: + exact_score = evaluator.eval_exact_match(p_sql, g_sql) + partial_scores = evaluator.partial_scores + if exact_score == 0: + print("{} pred: {}".format(hardness,p_str)) + print("{} gold: {}".format(hardness,g_str)) + print("") + scores[hardness]['exact'] += exact_score + 
+            scores['all']['exact'] += exact_score
+            for type_ in partial_types:
+                if partial_scores[type_]['pred_total'] > 0:
+                    scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
+                    scores[hardness]['partial'][type_]['acc_count'] += 1
+                if partial_scores[type_]['label_total'] > 0:
+                    scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
+                    scores[hardness]['partial'][type_]['rec_count'] += 1
+                scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
+                if partial_scores[type_]['pred_total'] > 0:
+                    scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
+                    scores['all']['partial'][type_]['acc_count'] += 1
+                if partial_scores[type_]['label_total'] > 0:
+                    scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
+                    scores['all']['partial'][type_]['rec_count'] += 1
+                scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']
+
+            entries.append({
+                'predictSQL': p_str,
+                'goldSQL': g_str,
+                'hardness': hardness,
+                'exact': exact_score,
+                'partial': partial_scores
+            })
+        index += 1
+
+    for level in levels:
+        if scores[level]['count'] == 0:
+            continue
+        if etype in ["all", "exec"]:
+            scores[level]['exec'] /= scores[level]['count']
+
+        if etype in ["all", "match"]:
+            scores[level]['exact'] /= scores[level]['count']
+            for type_ in partial_types:
+                if scores[level]['partial'][type_]['acc_count'] == 0:
+                    scores[level]['partial'][type_]['acc'] = 0
+                else:
+                    scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
+                        scores[level]['partial'][type_]['acc_count'] * 1.0
+                if scores[level]['partial'][type_]['rec_count'] == 0:
+                    scores[level]['partial'][type_]['rec'] = 0
+                else:
+                    scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
+                        scores[level]['partial'][type_]['rec_count'] * 1.0
+                if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
+                    scores[level]['partial'][type_]['f1'] = 1
+                else:
+                    scores[level]['partial'][type_]['f1'] = \
+                        2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
+                            scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
+
+    print_scores(scores, etype)
+    print(scores['all']['exec'])
+    current_directory = os.path.dirname(os.path.abspath(__file__))
+    file_name = current_directory + "/eval.json"
+    json_exist = os.path.exists(file_name)
+    if json_exist:
+        os.remove(file_name)
+    with open(file_name, 'w') as json_file:
+        json.dump(log_list, json_file, indent=4, ensure_ascii=False)
+
+
+def eval_exec_match(db, p_str, g_str, pred, gold):
+    """
+    Return 1 if the values of the prediction and the gold query match
+    at the corresponding indexes. Multiple col_unit pairs are currently
+    not supported.
+    """
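+    # Illustration of the comparison implemented below: each result set is
+    # reduced to a {column_name: sorted values} map (the sys_imp_date column
+    # is ignored) and the two maps are compared for equality; e.g. rows
+    # [(3, 'b'), (1, 'a')] under hypothetical fields ['cnt', 'name'] reduce
+    # to {'cnt': [1, 3], 'name': ['a', 'b']}.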
+ """ + conn = sqlite3.connect(db) + cursor = conn.cursor() + try: + cursor.execute(p_str) + columns_tuple = cursor.description + p_fields = [field_tuple[0] for field_tuple in columns_tuple] + for index in range(0,len(p_fields)): + p_fields[index]=re.sub("t\d+.", "",p_fields[index].replace("`","").lower()) + p_res = cursor.fetchall() + except: + return False + + cursor.execute(g_str) + q_res = cursor.fetchall() + + def res_map(res, p_fields): + rmap = {} + for i in range(0,len(p_fields)): + if p_fields[i] != "sys_imp_date": + value_list= [r[i] for r in res] + value_list.sort() + rmap[p_fields[i]] =value_list + return rmap + + g_fields = parse_sql(g_str) + + #print("p_res_map:{}".format(res_map(p_res, p_fields))) + #print("q_res_map:{}".format(res_map(q_res, g_fields))) + return res_map(p_res, p_fields) == res_map(q_res, g_fields) + +def parse_sql(sql): + # 使用 sqlparse 库解析 SQL 查询语句 + parsed = sqlparse.parse(sql)[0] + + # 获取查询类型(SELECT、INSERT、UPDATE 或 DELETE) + query_type = parsed.get_type() + + # 获取查询目标(表名、字段列表、值列表等) + if query_type == 'SELECT': + target = parse_select(parsed) + else: + target = None + + return target + + +def parse_select(parsed): + # 获取字段列表 + fields = [] + for token in parsed.tokens: + # + if isinstance(token, sqlparse.sql.IdentifierList): + for identifier in token.get_identifiers(): + fields.append(identifier.value.replace("`", "") + .replace("T1.", "").replace("T2.", "") + .replace("T3.", "").replace("T4.", "") + .replace("T5.", "").replace("T6.", "")) + if(len(fields)): + break + return fields + +# Rebuild SQL functions for value evaluation +def rebuild_cond_unit_val(cond_unit): + if cond_unit is None or not DISABLE_VALUE: + return cond_unit + + not_op, op_id, val_unit, val1, val2 = cond_unit + if type(val1) is not dict: + val1 = None + else: + val1 = rebuild_sql_val(val1) + if type(val2) is not dict: + val2 = None + else: + val2 = rebuild_sql_val(val2) + return not_op, op_id, val_unit, val1, val2 + + +def rebuild_condition_val(condition): + if condition is None or not DISABLE_VALUE: + return condition + + res = [] + for idx, it in enumerate(condition): + if idx % 2 == 0: + res.append(rebuild_cond_unit_val(it)) + else: + res.append(it) + return res + + +def rebuild_sql_val(sql): + if sql is None or not DISABLE_VALUE: + return sql + + sql['from']['conds'] = rebuild_condition_val(sql['from']['conds']) + sql['having'] = rebuild_condition_val(sql['having']) + sql['where'] = rebuild_condition_val(sql['where']) + sql['intersect'] = rebuild_sql_val(sql['intersect']) + sql['except'] = rebuild_sql_val(sql['except']) + sql['union'] = rebuild_sql_val(sql['union']) + + return sql + + +# Rebuild SQL functions for foreign key evaluation +def build_valid_col_units(table_units, schema): + col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']] + prefixs = [col_id[:-2] for col_id in col_ids] + valid_col_units= [] + for value in schema.idMap.values(): + if '.' 
in value and value[:value.index('.')] in prefixs: + valid_col_units.append(value) + return valid_col_units + + +def rebuild_col_unit_col(valid_col_units, col_unit, kmap): + if col_unit is None: + return col_unit + + agg_id, col_id, distinct = col_unit + if col_id in kmap and col_id in valid_col_units: + col_id = kmap[col_id] + if DISABLE_DISTINCT: + distinct = None + return agg_id, col_id, distinct + + +def rebuild_val_unit_col(valid_col_units, val_unit, kmap): + if val_unit is None: + return val_unit + + unit_op, col_unit1, col_unit2 = val_unit + col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap) + col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap) + return unit_op, col_unit1, col_unit2 + + +def rebuild_table_unit_col(valid_col_units, table_unit, kmap): + if table_unit is None: + return table_unit + + table_type, col_unit_or_sql = table_unit + if isinstance(col_unit_or_sql, tuple): + col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap) + return table_type, col_unit_or_sql + + +def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap): + if cond_unit is None: + return cond_unit + + not_op, op_id, val_unit, val1, val2 = cond_unit + val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap) + return not_op, op_id, val_unit, val1, val2 + + +def rebuild_condition_col(valid_col_units, condition, kmap): + for idx in range(len(condition)): + if idx % 2 == 0: + condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap) + return condition + + +def rebuild_select_col(valid_col_units, sel, kmap): + if sel is None: + return sel + distinct, _list = sel + new_list = [] + for it in _list: + agg_id, val_unit = it + new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap))) + if DISABLE_DISTINCT: + distinct = None + return distinct, new_list + + +def rebuild_from_col(valid_col_units, from_, kmap): + if from_ is None: + return from_ + + from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']] + from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap) + return from_ + + +def rebuild_group_by_col(valid_col_units, group_by, kmap): + if group_by is None: + return group_by + + return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by] + + +def rebuild_order_by_col(valid_col_units, order_by, kmap): + if order_by is None or len(order_by) == 0: + return order_by + + direction, val_units = order_by + new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units] + return direction, new_val_units + + +def rebuild_sql_col(valid_col_units, sql, kmap): + if sql is None: + return sql + + sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap) + sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap) + sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap) + sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap) + sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap) + sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap) + sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap) + sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap) + sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap) + + return sql + + +def build_foreign_key_map(entry): + cols_orig = entry["column_names_original"] + tables_orig = 
entry["table_names_original"] + + # rebuild cols corresponding to idmap in Schema + cols = [] + for col_orig in cols_orig: + if col_orig[0] >= 0: + t = tables_orig[col_orig[0]] + c = col_orig[1] + cols.append("__" + t.lower() + "." + c.lower() + "__") + else: + cols.append("__all__") + + def keyset_in_list(k1, k2, k_list): + for k_set in k_list: + if k1 in k_set or k2 in k_set: + return k_set + new_k_set = set() + k_list.append(new_k_set) + return new_k_set + + foreign_key_list = [] + foreign_keys = entry["foreign_keys"] + for fkey in foreign_keys: + key1, key2 = fkey + key_set = keyset_in_list(key1, key2, foreign_key_list) + key_set.add(key1) + key_set.add(key2) + + foreign_key_map = {} + for key_set in foreign_key_list: + sorted_list = sorted(list(key_set)) + midx = sorted_list[0] + for idx in sorted_list: + foreign_key_map[cols[idx]] = cols[midx] + + return foreign_key_map + + +def build_foreign_key_map_from_json(table): + with open(table) as f: + data = json.load(f) + tables = {} + for entry in data: + tables[entry['db_id']] = build_foreign_key_map(entry) + return tables + +def get_evaluation_result(): + current_directory = os.path.dirname(os.path.abspath(__file__)) + config_file=current_directory+"/config/config.yaml" + with open(config_file, 'r') as file: + config = yaml.safe_load(file) + db_dir=current_directory+"/data" + db_path=current_directory+"/data/" + db_file=db_path+config["domain"]+".db" + pred = current_directory+"/data/"+"pred_example_dusql.txt" + gold = current_directory+"/data/"+"gold_example_dusql.txt" + table= current_directory+"/data/"+"tables_dusql.json" + query_path=current_directory+"/data/"+config["domain"]+".txt" + etype="exec" + kmaps = build_foreign_key_map_from_json(table) + + evaluate(gold, pred, db_dir, etype, kmaps,query_path) + +def remove_unused_file(): + current_directory = os.path.dirname(os.path.abspath(__file__)) + config_file=current_directory+"/config/config.yaml" + with open(config_file, 'r') as file: + config = yaml.safe_load(file) + db_path=current_directory+"/data/" + db_file=db_path+config["domain"]+".db" + pred_file = current_directory+"/data/"+"pred_example_dusql.txt" + + db_exist=os.path.exists(db_file) + if db_exist: + os.remove(db_file) + print("db_file removed!") + pred_exist=os.path.exists(pred_file) + if pred_exist: + os.remove(pred_file) + print("pred_file removed!") + +if __name__ == "__main__": + build_table() + get_pred_result() + get_evaluation_result() + remove_unused_file() + + + diff --git a/evaluation/evaluation.sh b/evaluation/evaluation.sh new file mode 100644 index 000000000..0ffffe703 --- /dev/null +++ b/evaluation/evaluation.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash + +path=$(pwd) +echo ${path} + +python_path=${PYTHON_PATH:-"python3"} +pip_path=${PIP_PATH:-"pip3"} + +requirementPath=$path/requirements.txt +${pip_path} install -r ${requirementPath} +echo "install python modules success" +python $path/evaluation.py diff --git a/evaluation/process_sql.py b/evaluation/process_sql.py new file mode 100644 index 000000000..666560995 --- /dev/null +++ b/evaluation/process_sql.py @@ -0,0 +1,566 @@ +################################ +# Assumptions: +# 1. sql is correct +# 2. only table name has alias +# 3. only one intersect/union/except +# +# val: number(float)/string(str)/sql(dict) +# col_unit: (agg_id, col_id, isDistinct(bool)) +# val_unit: (unit_op, col_unit1, col_unit2) +# table_unit: (table_type, col_unit/sql) +# cond_unit: (not_op, op_id, val_unit, val1, val2) +# condition: [cond_unit1, 'and'/'or', cond_unit2, ...] 
+# sql { +# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...]) +# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition} +# 'where': condition +# 'groupBy': [col_unit1, col_unit2, ...] +# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...]) +# 'having': condition +# 'limit': None/limit value +# 'intersect': None/sql +# 'except': None/sql +# 'union': None/sql +# } +################################ + +import json +import sqlite3 +from nltk import word_tokenize + +CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except') +JOIN_KEYWORDS = ('join', 'on', 'as') + +WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists') +UNIT_OPS = ('none', '-', '+', "*", '/') +AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg') +TABLE_TYPE = { + 'sql': "sql", + 'table_unit': "table_unit", +} + +COND_OPS = ('and', 'or') +SQL_OPS = ('intersect', 'union', 'except') +ORDER_OPS = ('desc', 'asc') + + + +class Schema: + """ + Simple schema which maps table&column to a unique identifier + """ + def __init__(self, schema): + self._schema = schema + self._idMap = self._map(self._schema) + + @property + def schema(self): + return self._schema + + @property + def idMap(self): + return self._idMap + + def _map(self, schema): + idMap = {'*': "__all__"} + id = 1 + for key, vals in schema.items(): + for val in vals: + idMap[key.lower() + "." + val.lower()] = "__" + key.lower() + "." + val.lower() + "__" + id += 1 + + for key in schema: + idMap[key.lower()] = "__" + key.lower() + "__" + id += 1 + + return idMap + + +def get_schema(db): + """ + Get database's schema, which is a dict with table name as key + and list of column names as value + :param db: database path + :return: schema dict + """ + + schema = {} + conn = sqlite3.connect(db) + cursor = conn.cursor() + + # fetch table names + cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") + tables = [str(table[0].lower()) for table in cursor.fetchall()] + # tables.remove("internet.company") + # tables.remove("internet.brand") + # tables.remove("internet.company_revenue") + # tables.remove("internet.company_brand_revenue") + + # fetch table info + for table in tables: + cursor.execute("PRAGMA table_info({})".format(table)) + schema[table] = [str(col[1].lower()) for col in cursor.fetchall()] + + return schema + + +def get_schema_from_json(fpath): + with open(fpath) as f: + data = json.load(f) + + schema = {} + for entry in data: + table = str(entry['table'].lower()) + cols = [str(col['column_name'].lower()) for col in entry['col_data']] + schema[table] = cols + + return schema + + +def tokenize(string): + string = str(string) + string = string.replace("\'", "\"") # ensures all string values wrapped by "" problem?? 
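+    # Illustration of the masking below: the span between each pair of quotes
+    # is temporarily replaced by a placeholder key such as __val_7_12__ (the
+    # indexes are the quote positions, here hypothetical) so word_tokenize
+    # cannot split a quoted literal; the original value is substituted back
+    # into the token list afterwards.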
+ quote_idxs = [idx for idx, char in enumerate(string) if char == '"'] + assert len(quote_idxs) % 2 == 0, "Unexpected quote" + + # keep string value as token + vals = {} + for i in range(len(quote_idxs)-1, -1, -2): + qidx1 = quote_idxs[i-1] + qidx2 = quote_idxs[i] + val = string[qidx1: qidx2+1] + key = "__val_{}_{}__".format(qidx1, qidx2) + string = string[:qidx1] + key + string[qidx2+1:] + vals[key] = val + + toks = [word.lower() for word in word_tokenize(string)] + # replace with string value token + for i in range(len(toks)): + if toks[i] in vals: + toks[i] = vals[toks[i]] + + # find if there exists !=, >=, <= + eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="] + eq_idxs.reverse() + prefix = ('!', '>', '<') + for eq_idx in eq_idxs: + pre_tok = toks[eq_idx-1] + if pre_tok in prefix: + toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ] + + return toks + + +def scan_alias(toks): + """Scan the index of 'as' and build the map for all alias""" + as_idxs = [idx for idx, tok in enumerate(toks) if tok == 'as'] + alias = {} + for idx in as_idxs: + alias[toks[idx+1]] = toks[idx-1] + return alias + + +def get_tables_with_alias(schema, toks): + tables = scan_alias(toks) + for key in schema: + assert key not in tables, "Alias {} has the same name in table".format(key) + tables[key] = key + return tables + + +def parse_col(toks, start_idx, tables_with_alias, schema, default_tables=None): + """ + :returns next idx, column id + """ + tok = toks[start_idx] + if tok == "*": + return start_idx + 1, schema.idMap[tok] + + if '.' in tok: # if token is a composite + alias, col = tok.split('.') + key = tables_with_alias[alias] + "." + col + return start_idx+1, schema.idMap[key] + + assert default_tables is not None and len(default_tables) > 0, "Default tables should not be None or empty" + + for alias in default_tables: + table = tables_with_alias[alias] + if tok in schema.schema[table]: + key = table + "." 
+ tok + return start_idx+1, schema.idMap[key] + + assert False, "Error col: {}".format(tok) + + +def parse_col_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): + """ + :returns next idx, (agg_op id, col_id) + """ + idx = start_idx + len_ = len(toks) + isBlock = False + isDistinct = False + if toks[idx] == '(': + isBlock = True + idx += 1 + + if toks[idx] in AGG_OPS: + agg_id = AGG_OPS.index(toks[idx]) + idx += 1 + assert idx < len_ and toks[idx] == '(' + idx += 1 + if toks[idx] == "distinct": + idx += 1 + isDistinct = True + idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) + assert idx < len_ and toks[idx] == ')' + idx += 1 + return idx, (agg_id, col_id, isDistinct) + + if toks[idx] == "distinct": + idx += 1 + isDistinct = True + agg_id = AGG_OPS.index("none") + idx, col_id = parse_col(toks, idx, tables_with_alias, schema, default_tables) + + if isBlock: + assert toks[idx] == ')' + idx += 1 # skip ')' + + return idx, (agg_id, col_id, isDistinct) + + +def parse_val_unit(toks, start_idx, tables_with_alias, schema, default_tables=None): + idx = start_idx + len_ = len(toks) + isBlock = False + if toks[idx] == '(': + isBlock = True + idx += 1 + + col_unit1 = None + col_unit2 = None + unit_op = UNIT_OPS.index('none') + + idx, col_unit1 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) + if idx < len_ and toks[idx] in UNIT_OPS: + unit_op = UNIT_OPS.index(toks[idx]) + idx += 1 + idx, col_unit2 = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) + + if isBlock: + assert toks[idx] == ')' + idx += 1 # skip ')' + + return idx, (unit_op, col_unit1, col_unit2) + + +def parse_table_unit(toks, start_idx, tables_with_alias, schema): + """ + :returns next idx, table id, table name + """ + idx = start_idx + len_ = len(toks) + key = tables_with_alias[toks[idx]] + + if idx + 1 < len_ and toks[idx+1] == "as": + idx += 3 + else: + idx += 1 + + return idx, schema.idMap[key], key + + +def parse_value(toks, start_idx, tables_with_alias, schema, default_tables=None): + idx = start_idx + len_ = len(toks) + + isBlock = False + if toks[idx] == '(': + isBlock = True + idx += 1 + + if toks[idx] == 'select': + idx, val = parse_sql(toks, idx, tables_with_alias, schema) + elif "\"" in toks[idx]: # token is a string value + val = toks[idx] + idx += 1 + else: + try: + val = float(toks[idx]) + idx += 1 + except: + end_idx = idx + while end_idx < len_ and toks[end_idx] != ',' and toks[end_idx] != ')'\ + and toks[end_idx] != 'and' and toks[end_idx] not in CLAUSE_KEYWORDS and toks[end_idx] not in JOIN_KEYWORDS: + end_idx += 1 + + idx, val = parse_col_unit(toks[start_idx: end_idx], 0, tables_with_alias, schema, default_tables) + idx = end_idx + + if isBlock: + assert toks[idx] == ')' + idx += 1 + + return idx, val + + +def parse_condition(toks, start_idx, tables_with_alias, schema, default_tables=None): + idx = start_idx + len_ = len(toks) + conds = [] + + while idx < len_: + idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) + not_op = False + if toks[idx] == 'not': + not_op = True + idx += 1 + + assert idx < len_ and toks[idx] in WHERE_OPS, "Error condition: idx: {}, tok: {}".format(idx, toks[idx]) + op_id = WHERE_OPS.index(toks[idx]) + idx += 1 + val1 = val2 = None + if op_id == WHERE_OPS.index('between'): # between..and... 
special case: dual values + idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) + assert toks[idx] == 'and' + idx += 1 + idx, val2 = parse_value(toks, idx, tables_with_alias, schema, default_tables) + else: # normal case: single value + idx, val1 = parse_value(toks, idx, tables_with_alias, schema, default_tables) + val2 = None + + conds.append((not_op, op_id, val_unit, val1, val2)) + + if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";") or toks[idx] in JOIN_KEYWORDS): + break + + if idx < len_ and toks[idx] in COND_OPS: + conds.append(toks[idx]) + idx += 1 # skip and/or + + return idx, conds + + +def parse_select(toks, start_idx, tables_with_alias, schema, default_tables=None): + idx = start_idx + len_ = len(toks) + + assert toks[idx] == 'select', "'select' not found" + idx += 1 + isDistinct = False + if idx < len_ and toks[idx] == 'distinct': + idx += 1 + isDistinct = True + val_units = [] + + while idx < len_ and toks[idx] not in CLAUSE_KEYWORDS: + agg_id = AGG_OPS.index("none") + if toks[idx] in AGG_OPS: + agg_id = AGG_OPS.index(toks[idx]) + idx += 1 + idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) + val_units.append((agg_id, val_unit)) + if idx < len_ and toks[idx] == ',': + idx += 1 # skip ',' + + return idx, (isDistinct, val_units) + + +def parse_from(toks, start_idx, tables_with_alias, schema): + """ + Assume in the from clause, all table units are combined with join + """ + assert 'from' in toks[start_idx:], "'from' not found" + + len_ = len(toks) + idx = toks.index('from', start_idx) + 1 + default_tables = [] + table_units = [] + conds = [] + + while idx < len_: + isBlock = False + if toks[idx] == '(': + isBlock = True + idx += 1 + + if toks[idx] == 'select': + idx, sql = parse_sql(toks, idx, tables_with_alias, schema) + table_units.append((TABLE_TYPE['sql'], sql)) + else: + if idx < len_ and toks[idx] == 'join': + idx += 1 # skip join + idx, table_unit, table_name = parse_table_unit(toks, idx, tables_with_alias, schema) + table_units.append((TABLE_TYPE['table_unit'],table_unit)) + default_tables.append(table_name) + if idx < len_ and toks[idx] == "on": + idx += 1 # skip on + idx, this_conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) + if len(conds) > 0: + conds.append('and') + conds.extend(this_conds) + + if isBlock: + assert toks[idx] == ')' + idx += 1 + if idx < len_ and (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): + break + + return idx, table_units, conds, default_tables + + +def parse_where(toks, start_idx, tables_with_alias, schema, default_tables): + idx = start_idx + len_ = len(toks) + + if idx >= len_ or toks[idx] != 'where': + return idx, [] + + idx += 1 + idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) + return idx, conds + + +def parse_group_by(toks, start_idx, tables_with_alias, schema, default_tables): + idx = start_idx + len_ = len(toks) + col_units = [] + + if idx >= len_ or toks[idx] != 'group': + return idx, col_units + + idx += 1 + assert toks[idx] == 'by' + idx += 1 + + while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): + idx, col_unit = parse_col_unit(toks, idx, tables_with_alias, schema, default_tables) + col_units.append(col_unit) + if idx < len_ and toks[idx] == ',': + idx += 1 # skip ',' + else: + break + + return idx, col_units + + +def parse_order_by(toks, start_idx, tables_with_alias, schema, default_tables): + idx = start_idx + len_ = len(toks) + val_units 
= [] + order_type = 'asc' # default type is 'asc' + + if idx >= len_ or toks[idx] != 'order': + return idx, val_units + + idx += 1 + assert toks[idx] == 'by' + idx += 1 + + while idx < len_ and not (toks[idx] in CLAUSE_KEYWORDS or toks[idx] in (")", ";")): + idx, val_unit = parse_val_unit(toks, idx, tables_with_alias, schema, default_tables) + val_units.append(val_unit) + if idx < len_ and toks[idx] in ORDER_OPS: + order_type = toks[idx] + idx += 1 + if idx < len_ and toks[idx] == ',': + idx += 1 # skip ',' + else: + break + + return idx, (order_type, val_units) + + +def parse_having(toks, start_idx, tables_with_alias, schema, default_tables): + idx = start_idx + len_ = len(toks) + + if idx >= len_ or toks[idx] != 'having': + return idx, [] + + idx += 1 + idx, conds = parse_condition(toks, idx, tables_with_alias, schema, default_tables) + return idx, conds + + +def parse_limit(toks, start_idx): + idx = start_idx + len_ = len(toks) + + if idx < len_ and toks[idx] == 'limit': + idx += 2 + return idx, int(toks[idx-1]) + + return idx, None + + +def parse_sql(toks, start_idx, tables_with_alias, schema): + isBlock = False # indicate whether this is a block of sql/sub-sql + len_ = len(toks) + idx = start_idx + + sql = {} + if toks[idx] == '(': + isBlock = True + idx += 1 + + # parse from clause in order to get default tables + from_end_idx, table_units, conds, default_tables = parse_from(toks, start_idx, tables_with_alias, schema) + sql['from'] = {'table_units': table_units, 'conds': conds} + # select clause + _, select_col_units = parse_select(toks, idx, tables_with_alias, schema, default_tables) + idx = from_end_idx + sql['select'] = select_col_units + # where clause + idx, where_conds = parse_where(toks, idx, tables_with_alias, schema, default_tables) + sql['where'] = where_conds + # group by clause + idx, group_col_units = parse_group_by(toks, idx, tables_with_alias, schema, default_tables) + sql['groupBy'] = group_col_units + # having clause + idx, having_conds = parse_having(toks, idx, tables_with_alias, schema, default_tables) + sql['having'] = having_conds + # order by clause + idx, order_col_units = parse_order_by(toks, idx, tables_with_alias, schema, default_tables) + sql['orderBy'] = order_col_units + # limit clause + idx, limit_val = parse_limit(toks, idx) + sql['limit'] = limit_val + + idx = skip_semicolon(toks, idx) + if isBlock: + assert toks[idx] == ')' + idx += 1 # skip ')' + idx = skip_semicolon(toks, idx) + + # intersect/union/except clause + for op in SQL_OPS: # initialize IUE + sql[op] = None + if idx < len_ and toks[idx] in SQL_OPS: + sql_op = toks[idx] + idx += 1 + idx, IUE_sql = parse_sql(toks, idx, tables_with_alias, schema) + sql[sql_op] = IUE_sql + return idx, sql + + +def load_data(fpath): + with open(fpath) as f: + data = json.load(f) + return data + + +def get_sql(schema, query): + toks = tokenize(query) + tables_with_alias = get_tables_with_alias(schema.schema, toks) + _, sql = parse_sql(toks, 0, tables_with_alias, schema) + #print(sql) + return sql + + +def skip_semicolon(toks, start_idx): + idx = start_idx + while idx < len(toks) and toks[idx] == ";": + idx += 1 + return idx diff --git a/evaluation/requirements.txt b/evaluation/requirements.txt new file mode 100644 index 000000000..c963161ac --- /dev/null +++ b/evaluation/requirements.txt @@ -0,0 +1,6 @@ +pysqlite3==0.5.2 +PyJWT==2.8.0 +PyYAML==6.0.1 +sqlparse==0.4.4 +nltk==3.8.1 + diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/ChatDemoLoader.java 
b/launchers/standalone/src/main/java/com/tencent/supersonic/ChatDemoLoader.java index 4435234aa..b22c6d344 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/ChatDemoLoader.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/ChatDemoLoader.java @@ -91,6 +91,7 @@ public class ChatDemoLoader implements CommandLineRunner { addAgent1(); addAgent2(); addAgent3(); + addAgent4(); addSampleChats(); addSampleChats2(); updateQueryScore(1); @@ -508,6 +509,26 @@ public class ChatDemoLoader implements CommandLineRunner { agentService.createAgent(agent, User.getFakeUser()); } + private void addAgent4() { + Agent agent = new Agent(); + agent.setId(4); + agent.setName("DuSQL 互联网企业"); + agent.setDescription("DuSQL"); + agent.setStatus(1); + agent.setEnableSearch(1); + agent.setExamples(Lists.newArrayList()); + AgentConfig agentConfig = new AgentConfig(); + + LLMParserTool llmParserTool = new LLMParserTool(); + llmParserTool.setId("1"); + llmParserTool.setType(AgentToolType.NL2SQL_LLM); + llmParserTool.setModelIds(Lists.newArrayList(9L, 10L, 11L, 12L)); + agentConfig.getTools().add(llmParserTool); + + agent.setAgentConfig(JSONObject.toJSONString(agentConfig)); + agentService.createAgent(agent, User.getFakeUser()); + } + private void updateQueryScore(Integer queryId) { chatService.updateFeedback(queryId, 5, ""); } diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/DuSQLDemoDataLoader.java b/launchers/standalone/src/main/java/com/tencent/supersonic/DuSQLDemoDataLoader.java new file mode 100644 index 000000000..7ab84333f --- /dev/null +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/DuSQLDemoDataLoader.java @@ -0,0 +1,292 @@ +package com.tencent.supersonic; + +import com.google.common.collect.Lists; +import com.tencent.supersonic.auth.api.authentication.pojo.User; +import com.tencent.supersonic.common.pojo.JoinCondition; +import com.tencent.supersonic.common.pojo.ModelRela; +import com.tencent.supersonic.common.pojo.enums.AggOperatorEnum; +import com.tencent.supersonic.common.pojo.enums.FilterOperatorEnum; +import com.tencent.supersonic.headless.api.pojo.enums.DimensionType; +import com.tencent.supersonic.headless.api.pojo.enums.IdentifyType; +import com.tencent.supersonic.headless.api.pojo.Dim; +import com.tencent.supersonic.headless.api.pojo.DimensionTimeTypeParams; +import com.tencent.supersonic.headless.api.pojo.Identify; +import com.tencent.supersonic.headless.api.pojo.Measure; +import com.tencent.supersonic.headless.api.pojo.ModelDetail; +import com.tencent.supersonic.headless.api.pojo.request.DomainReq; +import com.tencent.supersonic.headless.api.pojo.request.MetricReq; +import com.tencent.supersonic.headless.api.pojo.request.ModelReq; +import com.tencent.supersonic.headless.api.pojo.response.MetricResp; +import com.tencent.supersonic.headless.server.service.DomainService; +import com.tencent.supersonic.headless.server.service.MetricService; +import com.tencent.supersonic.headless.server.service.ModelRelaService; +import com.tencent.supersonic.headless.server.service.ModelService; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.BeanUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +@Component +@Slf4j +public class DuSQLDemoDataLoader { + + private User user = User.getFakeUser(); + + @Autowired + private DomainService 
domainService; + @Autowired + private ModelService modelService; + @Autowired + private ModelRelaService modelRelaService; + @Autowired + private MetricService metricService; + + public void doRun() { + try { + addDomain(); + addModel_1(); + addModel_2(); + addModel_3(); + addModel_4(); + addModelRela_1(); + addModelRela_2(); + addModelRela_3(); + addModelRela_4(); + } catch (Exception e) { + log.error("Failed to add bench mark demo data", e); + } + + } + + public void addDomain() { + DomainReq domainReq = new DomainReq(); + domainReq.setName("DuSQL_互联网企业"); + domainReq.setBizName("internet"); + domainReq.setParentId(0L); + domainReq.setViewers(Arrays.asList("admin", "tom", "jack")); + domainReq.setViewOrgs(Collections.singletonList("1")); + domainReq.setAdmins(Collections.singletonList("admin")); + domainReq.setAdminOrgs(Collections.emptyList()); + domainService.createDomain(domainReq, user); + } + + //9 + public void addModel_1() throws Exception { + ModelReq modelReq = new ModelReq(); + modelReq.setName("公司"); + modelReq.setBizName("company"); + modelReq.setDescription("公司"); + modelReq.setDatabaseId(1L); + modelReq.setDomainId(4L); + modelReq.setViewers(Arrays.asList("admin", "tom", "jack")); + modelReq.setViewOrgs(Collections.singletonList("1")); + modelReq.setAdmins(Collections.singletonList("admin")); + modelReq.setAdminOrgs(Collections.emptyList()); + ModelDetail modelDetail = new ModelDetail(); + List dimensions = new ArrayList<>(); + Dim dimension1 = new Dim("", "imp_date", DimensionType.time.name(), 0); + DimensionTimeTypeParams dimensionTimeTypeParams = new DimensionTimeTypeParams("false", "none"); + dimension1.setTypeParams(dimensionTimeTypeParams); + dimensions.add(dimension1); + dimensions.add(new Dim("公司名称", "company_name", DimensionType.categorical.name(), 1)); + dimensions.add(new Dim("总部地点", "headquarter_address", DimensionType.categorical.name(), 1)); + dimensions.add(new Dim("公司成立时间", "company_established_time", DimensionType.categorical.name(), 1)); + dimensions.add(new Dim("创始人", "founder", DimensionType.categorical.name(), 1)); + dimensions.add(new Dim("首席执行官", "ceo", DimensionType.categorical.name(), 1)); + modelDetail.setDimensions(dimensions); + + List identifiers = new ArrayList<>(); + identifiers.add(new Identify("公司id", IdentifyType.primary.name(), "company_id")); + modelDetail.setIdentifiers(identifiers); + + List measures = new ArrayList<>(); + measures.add(new Measure("年营业额", "annual_turnover", AggOperatorEnum.SUM.name(), 1)); + Measure measure = new Measure("员工数", "employee_count", AggOperatorEnum.SUM.name(), 1); + measures.add(measure); + modelDetail.setMeasures(measures); + + modelDetail.setQueryType("sql_query"); + modelDetail.setSqlQuery("SELECT imp_date,company_id,company_name,headquarter_address," + + "company_established_time,founder,ceo,annual_turnover,employee_count FROM company"); + modelReq.setModelDetail(modelDetail); + modelService.createModel(modelReq, user); + } + + // 10 + public void addModel_2() throws Exception { + ModelReq modelReq = new ModelReq(); + modelReq.setName("品牌"); + modelReq.setBizName("brand"); + modelReq.setDescription("品牌"); + modelReq.setDatabaseId(1L); + modelReq.setDomainId(4L); + modelReq.setViewers(Arrays.asList("admin", "tom", "jack")); + modelReq.setViewOrgs(Collections.singletonList("1")); + modelReq.setAdmins(Collections.singletonList("admin")); + modelReq.setAdminOrgs(Collections.emptyList()); + ModelDetail modelDetail = new ModelDetail(); + List dimensions = new ArrayList<>(); + Dim dimension1 = new Dim("", 
"imp_date", DimensionType.time.name(), 0); + DimensionTimeTypeParams dimensionTimeTypeParams = new DimensionTimeTypeParams("false", "none"); + dimension1.setTypeParams(dimensionTimeTypeParams); + dimensions.add(dimension1); + dimensions.add(new Dim("品牌名称", "brand_name", DimensionType.categorical.name(), 1)); + dimensions.add(new Dim("品牌成立时间", "brand_established_time", DimensionType.categorical.name(), 1)); + dimensions.add(new Dim("法定代表人", "legal_representative", DimensionType.categorical.name(), 1)); + modelDetail.setDimensions(dimensions); + + List identifiers = new ArrayList<>(); + identifiers.add(new Identify("品牌id", IdentifyType.primary.name(), "brand_id")); + identifiers.add(new Identify("公司id", IdentifyType.foreign.name(), "company_id")); + modelDetail.setIdentifiers(identifiers); + + List measures = new ArrayList<>(); + measures.add(new Measure("注册资本", "registered_capital", AggOperatorEnum.SUM.name(), 1)); + modelDetail.setMeasures(measures); + + modelDetail.setQueryType("sql_query"); + modelDetail.setSqlQuery("SELECT imp_date,brand_id,brand_name,brand_established_time," + + "company_id,legal_representative,registered_capital FROM brand"); + modelReq.setModelDetail(modelDetail); + modelService.createModel(modelReq, user); + } + + // 11 + public void addModel_3() throws Exception { + ModelReq modelReq = new ModelReq(); + modelReq.setName("公司各品牌收入排名"); + modelReq.setBizName("company_revenue"); + modelReq.setDescription("公司各品牌收入排名"); + modelReq.setDatabaseId(1L); + modelReq.setDomainId(4L); + modelReq.setViewers(Arrays.asList("admin", "tom", "jack")); + modelReq.setViewOrgs(Collections.singletonList("1")); + modelReq.setAdmins(Collections.singletonList("admin")); + modelReq.setAdminOrgs(Collections.emptyList()); + ModelDetail modelDetail = new ModelDetail(); + List dimensions = new ArrayList<>(); + Dim dimension1 = new Dim("", "imp_date", DimensionType.time.name(), 0); + DimensionTimeTypeParams dimensionTimeTypeParams = new DimensionTimeTypeParams("false", "none"); + dimension1.setTypeParams(dimensionTimeTypeParams); + dimensions.add(dimension1); + modelDetail.setDimensions(dimensions); + + List identifiers = new ArrayList<>(); + identifiers.add(new Identify("公司id", IdentifyType.foreign.name(), "company_id")); + identifiers.add(new Identify("品牌id", IdentifyType.foreign.name(), "brand_id")); + modelDetail.setIdentifiers(identifiers); + + List measures = new ArrayList<>(); + Measure measure = new Measure("营收占比", "revenue_proportion", AggOperatorEnum.SUM.name(), 1); + measures.add(measure); + measures.add(new Measure("利润占比", "profit_proportion", AggOperatorEnum.SUM.name(), 1)); + measures.add(new Measure("支出占比", "expenditure_proportion", AggOperatorEnum.SUM.name(), 1)); + modelDetail.setMeasures(measures); + + modelDetail.setQueryType("sql_query"); + modelDetail.setSqlQuery("SELECT imp_date,company_id,brand_id,revenue_proportion," + + "profit_proportion,expenditure_proportion FROM company_revenue"); + modelReq.setModelDetail(modelDetail); + modelService.createModel(modelReq, user); + MetricResp metricResp = metricService.getMetric(13L, user); + + MetricReq metricReq = new MetricReq(); + BeanUtils.copyProperties(metricResp, metricReq); + metricReq.setAlias("收入比例"); + metricService.updateMetric(metricReq, user); + } + + // 12 + public void addModel_4() throws Exception { + ModelReq modelReq = new ModelReq(); + modelReq.setName("公司品牌历年收入"); + modelReq.setBizName("company_brand_revenue"); + modelReq.setDescription("公司品牌历年收入"); + modelReq.setDatabaseId(1L); + modelReq.setDomainId(4L); + 
modelReq.setViewers(Arrays.asList("admin", "tom", "jack")); + modelReq.setViewOrgs(Collections.singletonList("1")); + modelReq.setAdmins(Collections.singletonList("admin")); + modelReq.setAdminOrgs(Collections.emptyList()); + ModelDetail modelDetail = new ModelDetail(); + List dimensions = new ArrayList<>(); + Dim dimension1 = new Dim("", "imp_date", DimensionType.time.name(), 0); + DimensionTimeTypeParams dimensionTimeTypeParams = new DimensionTimeTypeParams("false", "none"); + dimension1.setTypeParams(dimensionTimeTypeParams); + dimensions.add(dimension1); + dimensions.add(new Dim("年份", "year_time", DimensionType.categorical.name(), 1)); + modelDetail.setDimensions(dimensions); + + List identifiers = new ArrayList<>(); + identifiers.add(new Identify("品牌id", IdentifyType.foreign.name(), "brand_id")); + modelDetail.setIdentifiers(identifiers); + + List measures = new ArrayList<>(); + measures.add(new Measure("营收", "revenue", AggOperatorEnum.SUM.name(), 1)); + measures.add(new Measure("利润", "profit", AggOperatorEnum.SUM.name(), 1)); + measures.add(new Measure("营收同比增长", "revenue_growth_year_on_year", AggOperatorEnum.SUM.name(), 1)); + measures.add(new Measure("利润同比增长", "profit_growth_year_on_year", AggOperatorEnum.SUM.name(), 1)); + modelDetail.setMeasures(measures); + + modelDetail.setQueryType("sql_query"); + modelDetail.setSqlQuery("SELECT imp_date,year_time,brand_id,revenue,profit," + + "revenue_growth_year_on_year,profit_growth_year_on_year FROM company_brand_revenue"); + modelReq.setModelDetail(modelDetail); + modelService.createModel(modelReq, user); + + } + + public void addModelRela_1() { + List joinConditions = Lists.newArrayList(); + joinConditions.add(new JoinCondition("company_id", "company_id", FilterOperatorEnum.EQUALS)); + ModelRela modelRelaReq = new ModelRela(); + modelRelaReq.setDomainId(4L); + modelRelaReq.setFromModelId(9L); + modelRelaReq.setToModelId(10L); + modelRelaReq.setJoinType("inner join"); + modelRelaReq.setJoinConditions(joinConditions); + modelRelaService.save(modelRelaReq, user); + } + + public void addModelRela_2() { + List joinConditions = Lists.newArrayList(); + joinConditions.add(new JoinCondition("company_id", "company_id", FilterOperatorEnum.EQUALS)); + ModelRela modelRelaReq = new ModelRela(); + modelRelaReq.setDomainId(4L); + modelRelaReq.setFromModelId(9L); + modelRelaReq.setToModelId(11L); + modelRelaReq.setJoinType("inner join"); + modelRelaReq.setJoinConditions(joinConditions); + modelRelaService.save(modelRelaReq, user); + } + + public void addModelRela_3() { + List joinConditions = Lists.newArrayList(); + joinConditions.add(new JoinCondition("brand_id", "brand_id", FilterOperatorEnum.EQUALS)); + ModelRela modelRelaReq = new ModelRela(); + modelRelaReq.setDomainId(4L); + modelRelaReq.setFromModelId(10L); + modelRelaReq.setToModelId(11L); + modelRelaReq.setJoinType("inner join"); + modelRelaReq.setJoinConditions(joinConditions); + modelRelaService.save(modelRelaReq, user); + } + + public void addModelRela_4() { + List joinConditions = Lists.newArrayList(); + joinConditions.add(new JoinCondition("brand_id", "brand_id", FilterOperatorEnum.EQUALS)); + ModelRela modelRelaReq = new ModelRela(); + modelRelaReq.setDomainId(4L); + modelRelaReq.setFromModelId(10L); + modelRelaReq.setToModelId(12L); + modelRelaReq.setJoinType("inner join"); + modelRelaReq.setJoinConditions(joinConditions); + modelRelaService.save(modelRelaReq, user); + } + +} diff --git a/launchers/standalone/src/main/java/com/tencent/supersonic/HeadlessDemoLoader.java 
b/launchers/standalone/src/main/java/com/tencent/supersonic/HeadlessDemoLoader.java index a4d01d3fc..be4a6e136 100644 --- a/launchers/standalone/src/main/java/com/tencent/supersonic/HeadlessDemoLoader.java +++ b/launchers/standalone/src/main/java/com/tencent/supersonic/HeadlessDemoLoader.java @@ -25,6 +25,9 @@ public class HeadlessDemoLoader implements CommandLineRunner { @Autowired private BenchMarkDemoDataLoader benchMarkDemoLoader; + @Autowired + private DuSQLDemoDataLoader duSQLDemoDataLoader; + @Value("${demo.enabled:false}") private boolean demoEnabled; @@ -36,6 +39,7 @@ public class HeadlessDemoLoader implements CommandLineRunner { } modelDataDemoLoader.doRun(); benchMarkDemoLoader.doRun(); + duSQLDemoDataLoader.doRun(); isLoad = true; } @@ -50,4 +54,4 @@ public class HeadlessDemoLoader implements CommandLineRunner { return isLoad; } -} \ No newline at end of file +} diff --git a/launchers/standalone/src/main/resources/application-local.yaml b/launchers/standalone/src/main/resources/application-local.yaml index 4ea7e1171..815e408a2 100644 --- a/launchers/standalone/src/main/resources/application-local.yaml +++ b/launchers/standalone/src/main/resources/application-local.yaml @@ -91,4 +91,4 @@ inMemoryEmbeddingStore: query: optimizer: - enable: true \ No newline at end of file + enable: true diff --git a/launchers/standalone/src/main/resources/data/dictionary/custom/DimValue_10_20.txt b/launchers/standalone/src/main/resources/data/dictionary/custom/DimValue_10_20.txt new file mode 100644 index 000000000..7bc20ae1d --- /dev/null +++ b/launchers/standalone/src/main/resources/data/dictionary/custom/DimValue_10_20.txt @@ -0,0 +1,8 @@ +阿里云 _10_20 5 +天猫 _10_20 5 +腾讯游戏 _10_20 5 +度小满 _10_20 5 +京东金融 _10_20 5 + + + diff --git a/launchers/standalone/src/main/resources/data/dictionary/custom/DimValue_10_22.txt b/launchers/standalone/src/main/resources/data/dictionary/custom/DimValue_10_22.txt new file mode 100644 index 000000000..4d9cccf5a --- /dev/null +++ b/launchers/standalone/src/main/resources/data/dictionary/custom/DimValue_10_22.txt @@ -0,0 +1,8 @@ +张勇 _10_22 5 +马化腾 _10_22 5 +朱光 _10_22 5 +刘强东 _10_22 5 + + + + diff --git a/launchers/standalone/src/main/resources/data/dictionary/custom/DimValue_9_15.txt b/launchers/standalone/src/main/resources/data/dictionary/custom/DimValue_9_15.txt new file mode 100644 index 000000000..5a2cb8f1d --- /dev/null +++ b/launchers/standalone/src/main/resources/data/dictionary/custom/DimValue_9_15.txt @@ -0,0 +1,5 @@ +百度集团 _9_15 5 +阿里巴巴集团 _9_15 5 +深圳市腾讯计算机系统有限公司 _9_15 5 +北京京东世纪贸易有限公司 _9_15 5 +网易公司 _9_15 5 diff --git a/launchers/standalone/src/main/resources/data/dictionary/custom/DimValue_9_16.txt b/launchers/standalone/src/main/resources/data/dictionary/custom/DimValue_9_16.txt new file mode 100644 index 000000000..c4e7f41d6 --- /dev/null +++ b/launchers/standalone/src/main/resources/data/dictionary/custom/DimValue_9_16.txt @@ -0,0 +1,4 @@ +北京 _9_16 5 +杭州 _9_16 5 +深圳 _9_16 5 + diff --git a/launchers/standalone/src/main/resources/data/dictionary/custom/DimValue_9_18.txt b/launchers/standalone/src/main/resources/data/dictionary/custom/DimValue_9_18.txt new file mode 100644 index 000000000..b15cc33d9 --- /dev/null +++ b/launchers/standalone/src/main/resources/data/dictionary/custom/DimValue_9_18.txt @@ -0,0 +1,7 @@ +李彦宏 _9_18 5 +马云 _9_18 5 +马化腾 _9_18 5 +刘强东 _9_18 5 +丁磊 _9_18 5 + + diff --git a/launchers/standalone/src/main/resources/data/dictionary/custom/DimValue_9_19.txt b/launchers/standalone/src/main/resources/data/dictionary/custom/DimValue_9_19.txt new file 
diff --git a/launchers/standalone/src/main/resources/data/dictionary/custom/DimValue_9_19.txt b/launchers/standalone/src/main/resources/data/dictionary/custom/DimValue_9_19.txt
new file mode 100644
index 000000000..233cfe17f
--- /dev/null
+++ b/launchers/standalone/src/main/resources/data/dictionary/custom/DimValue_9_19.txt
@@ -0,0 +1,7 @@
+李彦宏 _9_19 5
+张勇 _9_19 5
+刘炽平 _9_19 5
+刘强东 _9_19 5
+丁磊 _9_19 5
+
+
diff --git a/launchers/standalone/src/main/resources/db/data-h2.sql b/launchers/standalone/src/main/resources/db/data-h2.sql
index 5860e3f0b..72e046101 100644
--- a/launchers/standalone/src/main/resources/db/data-h2.sql
+++ b/launchers/standalone/src/main/resources/db/data-h2.sql
@@ -1111,4 +1111,29 @@ MERGE INTO song(imp_date,song_name,artist_name,country,f_id,g_name,rating,langua
 MERGE INTO song(imp_date,song_name,artist_name,country,f_id,g_name,rating,languages,releasedate,resolution) VALUES (DATEADD('DAY', 0, CURRENT_DATE()),'打败它','Michel','英国',5,'流行',8,'英文','17-MAR-2002',720);
 MERGE INTO song(imp_date,song_name,artist_name,country,f_id,g_name,rating,languages,releasedate,resolution) VALUES (DATEADD('DAY', 0, CURRENT_DATE()),'阿杰伊阿卡什','Topu','印度',6,'现代',10,'孟加拉语','27-MAR-2004',320);
 
+
+insert into company(imp_date,company_id,company_name,headquarter_address,company_established_time,founder,ceo,annual_turnover,employee_count) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_131','百度集团','北京','2000','李彦宏','李彦宏',102300000000,40000);
+insert into company(imp_date,company_id,company_name,headquarter_address,company_established_time,founder,ceo,annual_turnover,employee_count) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_132','阿里巴巴集团','杭州','1999年','马云','张勇',376800000000,103699);
+insert into company(imp_date,company_id,company_name,headquarter_address,company_established_time,founder,ceo,annual_turnover,employee_count) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_133','深圳市腾讯计算机系统有限公司','深圳','1998','马化腾','刘炽平',321600000000,56310);
+insert into company(imp_date,company_id,company_name,headquarter_address,company_established_time,founder,ceo,annual_turnover,employee_count) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_134','北京京东世纪贸易有限公司','北京','1998','刘强东','刘强东',28800000000,179000);
+insert into company(imp_date,company_id,company_name,headquarter_address,company_established_time,founder,ceo,annual_turnover,employee_count) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_135','网易公司','杭州','1997','丁磊','丁磊',67500000000,20000);
+
+insert into brand(imp_date,brand_id,brand_name,brand_established_time,company_id,legal_representative,registered_capital) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_136','阿里云','2009年9月10日','item_enterprise_13_134','张勇',50000000);
+insert into brand(imp_date,brand_id,brand_name,brand_established_time,company_id,legal_representative,registered_capital) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_137','天猫','2012年1月11日','item_enterprise_13_134','张勇',100000000);
+insert into brand(imp_date,brand_id,brand_name,brand_established_time,company_id,legal_representative,registered_capital) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_138','腾讯游戏','2003','item_enterprise_13_131','马化腾',50000000);
+insert into brand(imp_date,brand_id,brand_name,brand_established_time,company_id,legal_representative,registered_capital) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_139','度小满','2018','item_enterprise_13_132','朱光',100000000);
+insert into brand(imp_date,brand_id,brand_name,brand_established_time,company_id,legal_representative,registered_capital) VALUES (DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_140','京东金融','2017','item_enterprise_13_134','刘强东',100000000);
+
+insert into company_revenue(imp_date,company_id,brand_id,revenue_proportion,profit_proportion,expenditure_proportion) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_131','item_enterprise_13_139',10,10,30);
+insert into company_revenue(imp_date,company_id,brand_id,revenue_proportion,profit_proportion,expenditure_proportion) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_134','item_enterprise_13_138',80,80,60);
+insert into company_revenue(imp_date,company_id,brand_id,revenue_proportion,profit_proportion,expenditure_proportion) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_135','item_enterprise_13_139',80,80,60);
+insert into company_revenue(imp_date,company_id,brand_id,revenue_proportion,profit_proportion,expenditure_proportion) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_131','item_enterprise_13_137',80,80,60);
+insert into company_revenue(imp_date,company_id,brand_id,revenue_proportion,profit_proportion,expenditure_proportion) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'item_enterprise_13_135','item_enterprise_13_137',10,10,30);
+
+insert into company_brand_revenue(imp_date,year_time,brand_id,revenue,profit,revenue_growth_year_on_year,profit_growth_year_on_year) VALUES (DATEADD('DAY', -1, CURRENT_DATE()), '2018','item_enterprise_13_138',500000000,-300000000,10,-10);
+insert into company_brand_revenue(imp_date,year_time,brand_id,revenue,profit,revenue_growth_year_on_year,profit_growth_year_on_year) VALUES (DATEADD('DAY', -1, CURRENT_DATE()), '2019','item_enterprise_13_136',100000000000,50000000000,100,50);
+insert into company_brand_revenue(imp_date,year_time,brand_id,revenue,profit,revenue_growth_year_on_year,profit_growth_year_on_year) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'2018','item_enterprise_13_137',100000000000,50000000000,100,-10);
+insert into company_brand_revenue(imp_date,year_time,brand_id,revenue,profit,revenue_growth_year_on_year,profit_growth_year_on_year) VALUES (DATEADD('DAY', -1, CURRENT_DATE()), '2018','item_enterprise_13_139',500000000,50000000000,10,50);
+insert into company_brand_revenue(imp_date,year_time,brand_id,revenue,profit,revenue_growth_year_on_year,profit_growth_year_on_year) VALUES ( DATEADD('DAY', -1, CURRENT_DATE()),'2018','item_enterprise_13_138',100000000000,-300000000,10,50);
+
 -- benchmark
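The data-h2.sql block seeds the DuSQL demo tables with imp_date set to yesterday. DuSQLDemoDataLoader itself is referenced by HeadlessDemoLoader but not shown in this diff, so the following is a hypothetical sketch, using Spring's JdbcTemplate, of how such a loader could issue one of these inserts; the class name and wiring are illustrative only:

    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.jdbc.core.JdbcTemplate;
    import org.springframework.stereotype.Component;

    // Hypothetical sketch: the real DuSQLDemoDataLoader is not part of this
    // diff, so the name, wiring and method body here are assumptions.
    @Component
    public class DuSQLDemoDataLoaderSketch {

        @Autowired
        private JdbcTemplate jdbcTemplate;

        public void doRun() {
            // imp_date is "yesterday", matching the DATEADD('DAY', -1, CURRENT_DATE())
            // pattern used throughout data-h2.sql.
            jdbcTemplate.update(
                "insert into company(imp_date,company_id,company_name,headquarter_address,"
                    + "company_established_time,founder,ceo,annual_turnover,employee_count) "
                    + "values (DATEADD('DAY', -1, CURRENT_DATE()),?,?,?,?,?,?,?,?)",
                "item_enterprise_13_131", "百度集团", "北京", "2000", "李彦宏", "李彦宏",
                102300000000L, 40000);
        }
    }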
diff --git a/launchers/standalone/src/main/resources/db/schema-h2.sql b/launchers/standalone/src/main/resources/db/schema-h2.sql
index 273ad467f..9b653cbcf 100644
--- a/launchers/standalone/src/main/resources/db/schema-h2.sql
+++ b/launchers/standalone/src/main/resources/db/schema-h2.sql
@@ -455,6 +455,51 @@ CREATE TABLE IF NOT EXISTS `song` (
 );
 COMMENT ON TABLE song IS 'song';
 
+CREATE TABLE IF NOT EXISTS `company` (
+    `imp_date` varchar(50) ,
+    `company_id` varchar(50) NOT NULL ,
+    `company_name` varchar(50) NOT NULL ,
+    `headquarter_address` varchar(50) NOT NULL ,
+    `company_established_time` varchar(20) NOT NULL ,
+    `founder` varchar(20) NOT NULL ,
+    `ceo` varchar(20) NOT NULL ,
+    `annual_turnover` bigint(15) ,
+    `employee_count` int(7) ,
+    PRIMARY KEY (`company_id`)
+  );
+
+CREATE TABLE IF NOT EXISTS `brand` (
+    `imp_date` varchar(50) ,
+    `brand_id` varchar(50) NOT NULL ,
+    `brand_name` varchar(50) NOT NULL ,
+    `brand_established_time` varchar(20) NOT NULL ,
+    `company_id` varchar(50) NOT NULL ,
+    `legal_representative` varchar(20) NOT NULL ,
+    `registered_capital` bigint(15) ,
+    PRIMARY KEY (`brand_id`)
+  );
+
+CREATE TABLE IF NOT EXISTS `company_revenue` (
+    `imp_date` varchar(50) ,
+    `company_id` varchar(50) NOT NULL ,
+    `brand_id` varchar(50) NOT NULL ,
+    `revenue_proportion` double NOT NULL,
+    `profit_proportion` double NOT NULL ,
+    `expenditure_proportion` double NOT NULL
+  );
+
+CREATE TABLE IF NOT EXISTS `company_brand_revenue` (
+    `imp_date` varchar(50) ,
+    `year_time` varchar(10) NOT NULL ,
+    `brand_id` varchar(50) NOT NULL ,
+    `revenue` bigint(15) NOT NULL,
+    `profit` bigint(15) NOT NULL ,
+    `revenue_growth_year_on_year` double NOT NULL ,
+    `profit_growth_year_on_year` double NOT NULL
+  );
+
+
+
 
 CREATE TABLE IF NOT EXISTS s2_sys_parameter (
     id INT PRIMARY KEY AUTO_INCREMENT,
@@ -498,4 +543,4 @@ CREATE TABLE IF NOT EXISTS `s2_app` (
     created_by VARCHAR(255),
     updated_at TIMESTAMP,
     updated_by VARCHAR(255)
-);
\ No newline at end of file
+);
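In the new schema, brand.company_id is meant to join to company.company_id, and brand_id links both revenue tables back to brand. A hedged JDBC smoke test of that relationship, assuming an H2 database already initialized from schema-h2.sql and data-h2.sql; the JDBC URL below is illustrative and is not taken from this diff:

    import java.sql.Connection;
    import java.sql.DriverManager;
    import java.sql.ResultSet;
    import java.sql.Statement;

    // Smoke test sketch: lists each brand with its owning company.
    public class SchemaSmokeTest {
        public static void main(String[] args) throws Exception {
            try (Connection conn = DriverManager.getConnection("jdbc:h2:mem:supersonic", "sa", "");
                 Statement stmt = conn.createStatement();
                 ResultSet rs = stmt.executeQuery(
                     "SELECT b.brand_name, c.company_name "
                         + "FROM brand b JOIN company c ON b.company_id = c.company_id")) {
                while (rs.next()) {
                    System.out.println(rs.getString(1) + " -> " + rs.getString(2));
                }
            }
        }
    }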
diff --git a/launchers/standalone/src/main/resources/hanlp.properties b/launchers/standalone/src/main/resources/hanlp.properties
index 8faa512a4..729ad70ce 100644
--- a/launchers/standalone/src/main/resources/hanlp.properties
+++ b/launchers/standalone/src/main/resources/hanlp.properties
@@ -1,2 +1,2 @@
 root=.
-CustomDictionaryPath=data/dictionary/custom/DimValue_1_1.txt;data/dictionary/custom/DimValue_1_2.txt;data/dictionary/custom/DimValue_1_3.txt;data/dictionary/custom/benchmark_cspider.txt;
+CustomDictionaryPath=data/dictionary/custom/DimValue_1_1.txt;data/dictionary/custom/DimValue_1_2.txt;data/dictionary/custom/DimValue_9_15.txt;data/dictionary/custom/DimValue_9_16.txt;data/dictionary/custom/DimValue_9_18.txt;data/dictionary/custom/DimValue_9_19.txt;data/dictionary/custom/DimValue_10_20.txt;data/dictionary/custom/DimValue_10_22.txt;
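Once hanlp.properties lists the new dictionary files, HanLP should keep the dimension values as single tokens during segmentation. A small check against the stock HanLP 1.x API; supersonic may bundle a customized HanLP build, so treat this as a sketch of the idea rather than the project's actual loading path:

    import com.hankcs.hanlp.HanLP;
    import com.hankcs.hanlp.dictionary.CustomDictionary;

    // Assumes hanlp.properties (with the CustomDictionaryPath above) is on the
    // classpath, so the custom dictionaries load on first use.
    public class DictCheck {
        public static void main(String[] args) {
            System.out.println(CustomDictionary.contains("阿里云")); // expect true
            System.out.println(HanLP.segment("阿里云2019年的营收")); // "阿里云" should stay one token
        }
    }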