[improvement][supersonic] add text-to-sql evaluation (#696)

* [improvement] llm supports all models

* [improvement] alias convert to SemanticParseInfo

* [improvement] support join

* [improvement] add evaluation.py

* [improvement] add text2sql_evalution.py

* [improvement] add text2sql_evalution.py

* [improvement] add evalution

* [improvement] add evalution

* [improvement] add evalution

---------

Co-authored-by: zuopengge <hwzuopengge@tencent.com>
This commit is contained in:
mainmain
2024-01-30 10:46:45 +08:00
committed by GitHub
parent aae3d6b297
commit c398ac1a84
29 changed files with 3347 additions and 15 deletions

922
evaluation/evaluation.py Normal file
View File

@@ -0,0 +1,922 @@
################################
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
# 'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
# 'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
# 'where': condition
# 'groupBy': [col_unit1, col_unit2, ...]
# 'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
# 'having': condition
# 'limit': None/limit value
# 'intersect': None/sql
# 'except': None/sql
# 'union': None/sql
# }
################################
from __future__ import print_function
import sqlparse
import logging
import os, sys
import json
import sqlite3
import traceback
import argparse
import yaml
import re
from process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql
from build_pred_result import read_query,get_pred_result
from build_tables import build_table
# Flag to disable value evaluation
DISABLE_VALUE = True
# Flag to disable distinct in select evaluation
DISABLE_DISTINCT = True
CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', "*", '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
'sql': "sql",
'table_unit': "table_unit",
}
COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')
HARDNESS = {
"component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
"component2": ('except', 'union', 'intersect')
}
def condition_has_or(conds):
return 'or' in conds[1::2]
def condition_has_like(conds):
return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]
def condition_has_sql(conds):
for cond_unit in conds[::2]:
val1, val2 = cond_unit[3], cond_unit[4]
if val1 is not None and type(val1) is dict:
return True
if val2 is not None and type(val2) is dict:
return True
return False
def val_has_op(val_unit):
return val_unit[0] != UNIT_OPS.index('none')
def has_agg(unit):
return unit[0] != AGG_OPS.index('none')
def accuracy(count, total):
if count == total:
return 1
return 0
def recall(count, total):
if count == total:
return 1
return 0
def F1(acc, rec):
if (acc + rec) == 0:
return 0
return (2. * acc * rec) / (acc + rec)
def get_scores(count, pred_total, label_total):
if pred_total != label_total:
return 0,0,0
elif count == pred_total:
return 1,1,1
return 0,0,0
def eval_sel(pred, label):
pred_sel = pred['select'][1]
label_sel = label['select'][1]
label_wo_agg = [unit[1] for unit in label_sel]
pred_total = len(pred_sel)
label_total = len(label_sel)
cnt = 0
cnt_wo_agg = 0
for unit in pred_sel:
if unit in label_sel:
cnt += 1
label_sel.remove(unit)
if unit[1] in label_wo_agg:
cnt_wo_agg += 1
label_wo_agg.remove(unit[1])
return label_total, pred_total, cnt, cnt_wo_agg
def eval_where(pred, label):
pred_conds = [unit for unit in pred['where'][::2]]
label_conds = [unit for unit in label['where'][::2]]
label_wo_agg = [unit[2] for unit in label_conds]
pred_total = len(pred_conds)
label_total = len(label_conds)
cnt = 0
cnt_wo_agg = 0
for unit in pred_conds:
if unit in label_conds:
cnt += 1
label_conds.remove(unit)
if unit[2] in label_wo_agg:
cnt_wo_agg += 1
label_wo_agg.remove(unit[2])
return label_total, pred_total, cnt, cnt_wo_agg
def eval_group(pred, label):
pred_cols = [unit[1] for unit in pred['groupBy']]
label_cols = [unit[1] for unit in label['groupBy']]
pred_total = len(pred_cols)
label_total = len(label_cols)
cnt = 0
pred_cols = [pred.split(".")[1] if "." in pred else pred for pred in pred_cols]
label_cols = [label.split(".")[1] if "." in label else label for label in label_cols]
for col in pred_cols:
if col in label_cols:
cnt += 1
label_cols.remove(col)
return label_total, pred_total, cnt
def eval_having(pred, label):
pred_total = label_total = cnt = 0
if len(pred['groupBy']) > 0:
pred_total = 1
if len(label['groupBy']) > 0:
label_total = 1
pred_cols = [unit[1] for unit in pred['groupBy']]
label_cols = [unit[1] for unit in label['groupBy']]
if pred_total == label_total == 1 \
and pred_cols == label_cols \
and pred['having'] == label['having']:
cnt = 1
return label_total, pred_total, cnt
def eval_order(pred, label):
pred_total = label_total = cnt = 0
if len(pred['orderBy']) > 0:
pred_total = 1
if len(label['orderBy']) > 0:
label_total = 1
if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
((pred['limit'] is None and label['limit'] is None) or (pred['limit'] is not None and label['limit'] is not None)):
cnt = 1
return label_total, pred_total, cnt
def eval_and_or(pred, label):
pred_ao = pred['where'][1::2]
label_ao = label['where'][1::2]
pred_ao = set(pred_ao)
label_ao = set(label_ao)
if pred_ao == label_ao:
return 1,1,1
return len(pred_ao),len(label_ao),0
def get_nestedSQL(sql):
nested = []
for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
if type(cond_unit[3]) is dict:
nested.append(cond_unit[3])
if type(cond_unit[4]) is dict:
nested.append(cond_unit[4])
if sql['intersect'] is not None:
nested.append(sql['intersect'])
if sql['except'] is not None:
nested.append(sql['except'])
if sql['union'] is not None:
nested.append(sql['union'])
return nested
def eval_nested(pred, label):
label_total = 0
pred_total = 0
cnt = 0
if pred is not None:
pred_total += 1
if label is not None:
label_total += 1
if pred is not None and label is not None:
cnt += Evaluator().eval_exact_match(pred, label)
return label_total, pred_total, cnt
def eval_IUEN(pred, label):
lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
label_total = lt1 + lt2 + lt3
pred_total = pt1 + pt2 + pt3
cnt = cnt1 + cnt2 + cnt3
return label_total, pred_total, cnt
def get_keywords(sql):
res = set()
if len(sql['where']) > 0:
res.add('where')
if len(sql['groupBy']) > 0:
res.add('group')
if len(sql['having']) > 0:
res.add('having')
if len(sql['orderBy']) > 0:
res.add(sql['orderBy'][0])
res.add('order')
if sql['limit'] is not None:
res.add('limit')
if sql['except'] is not None:
res.add('except')
if sql['union'] is not None:
res.add('union')
if sql['intersect'] is not None:
res.add('intersect')
# or keyword
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
if len([token for token in ao if token == 'or']) > 0:
res.add('or')
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
# not keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
res.add('not')
# in keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
res.add('in')
# like keyword
if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
res.add('like')
return res
def eval_keywords(pred, label):
pred_keywords = get_keywords(pred)
label_keywords = get_keywords(label)
pred_total = len(pred_keywords)
label_total = len(label_keywords)
cnt = 0
for k in pred_keywords:
if k in label_keywords:
cnt += 1
return label_total, pred_total, cnt
def count_agg(units):
return len([unit for unit in units if has_agg(unit)])
def count_component1(sql):
count = 0
if len(sql['where']) > 0:
count += 1
if len(sql['groupBy']) > 0:
count += 1
if len(sql['orderBy']) > 0:
count += 1
if sql['limit'] is not None:
count += 1
if len(sql['from']['table_units']) > 0: # JOIN
count += len(sql['from']['table_units']) - 1
ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
count += len([token for token in ao if token == 'or'])
cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])
return count
def count_component2(sql):
nested = get_nestedSQL(sql)
return len(nested)
def count_others(sql):
count = 0
# number of aggregation
agg_count = count_agg(sql['select'][1])
agg_count += count_agg(sql['where'][::2])
agg_count += count_agg(sql['groupBy'])
if len(sql['orderBy']) > 0:
agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
[unit[2] for unit in sql['orderBy'][1] if unit[2]])
agg_count += count_agg(sql['having'])
if agg_count > 1:
count += 1
# number of select columns
if len(sql['select'][1]) > 1:
count += 1
# number of where conditions
if len(sql['where']) > 1:
count += 1
# number of group by clauses
if len(sql['groupBy']) > 1:
count += 1
return count
class Evaluator:
"""A simple evaluator"""
def __init__(self):
self.partial_scores = None
def eval_hardness(self, sql):
count_comp1_ = count_component1(sql)
count_comp2_ = count_component2(sql)
count_others_ = count_others(sql)
if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
return "easy"
elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
(count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
return "medium"
elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
(2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
(count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
return "hard"
else:
return "extra"
def eval_exact_match(self, pred, label):
partial_scores = self.eval_partial_match(pred, label)
self.partial_scores = partial_scores
for _, score in partial_scores.items():
if score['f1'] != 1:
return 0
if len(label['from']['table_units']) > 0:
label_tables = sorted(label['from']['table_units'])
pred_tables = sorted(pred['from']['table_units'])
return label_tables == pred_tables
return 1
def eval_partial_match(self, pred, label):
res = {}
label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['select'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['where'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_group(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_having(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['group'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_order(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['order'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_and_or(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_IUEN(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
label_total, pred_total, cnt = eval_keywords(pred, label)
acc, rec, f1 = get_scores(cnt, pred_total, label_total)
res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1,'label_total':label_total,'pred_total':pred_total}
return res
def isValidSQL(sql, db):
conn = sqlite3.connect(db)
cursor = conn.cursor()
try:
cursor.execute(sql)
except:
return False
return True
def print_scores(scores, etype):
levels = ['easy', 'medium', 'hard', 'extra', 'all']
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
'group', 'order', 'and/or', 'IUEN', 'keywords']
print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
counts = [scores[level]['count'] for level in levels]
print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))
if etype in ["all", "exec"]:
print('===================== EXECUTION ACCURACY =====================')
this_scores = [scores[level]['exec'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))
if etype in ["all", "match"]:
print('\n====================== EXACT MATCHING ACCURACY =====================')
exact_scores = [scores[level]['exact'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
print('\n---------------------PARTIAL MATCHING ACCURACY----------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
print('---------------------- PARTIAL MATCHING RECALL ----------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
print('---------------------- PARTIAL MATCHING F1 --------------------------')
for type_ in partial_types:
this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
def evaluate(gold, predict, db_dir, etype, kmaps,query_path):
with open(gold) as f:
glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
with open(predict) as f:
plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
# plist = [("select max(Share),min(Share) from performance where Type != 'terminal'", "orchestra")]
# glist = [("SELECT max(SHARE) , min(SHARE) FROM performance WHERE TYPE != 'Live final'", "orchestra")]
evaluator = Evaluator()
#print(plist)
levels = ['easy', 'medium', 'hard', 'extra', 'all']
partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
'group', 'order', 'and/or', 'IUEN', 'keywords']
entries = []
scores = {}
log_list=[]
for level in levels:
scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
scores[level]['exec'] = 0
for type_ in partial_types:
scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0.,'acc_count':0,'rec_count':0}
eval_err_num = 0
questions=read_query(query_path)
index=0
for p, g in zip(plist, glist):
p_str = p[0]
g_str, db = g
db_name = db
# db = os.path.join(db_dir, db, db + ".sqlite")
db = os.path.join(db_dir,db + ".db")
print(db)
schema = Schema(get_schema(db))
g_sql = get_sql(schema, g_str)
hardness = evaluator.eval_hardness(g_sql)
scores[hardness]['count'] += 1
scores['all']['count'] += 1
p_sql = p_str
if etype in ["all", "exec"]:
exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
if not exec_score:
element={}
element["query"]=questions[index]
element["gold_sql"]=g_str
element["pred_sql"]=p_str
log_list.append(element)
if exec_score:
scores[hardness]['exec'] += 1.0
scores['all']['exec'] += 1.0
if etype in ["all", "match"]:
exact_score = evaluator.eval_exact_match(p_sql, g_sql)
partial_scores = evaluator.partial_scores
if exact_score == 0:
print("{} pred: {}".format(hardness,p_str))
print("{} gold: {}".format(hardness,g_str))
print("")
scores[hardness]['exact'] += exact_score
scores['all']['exact'] += exact_score
for type_ in partial_types:
if partial_scores[type_]['pred_total'] > 0:
scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
scores[hardness]['partial'][type_]['acc_count'] += 1
if partial_scores[type_]['label_total'] > 0:
scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
scores[hardness]['partial'][type_]['rec_count'] += 1
scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
if partial_scores[type_]['pred_total'] > 0:
scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
scores['all']['partial'][type_]['acc_count'] += 1
if partial_scores[type_]['label_total'] > 0:
scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
scores['all']['partial'][type_]['rec_count'] += 1
scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']
entries.append({
'predictSQL': p_str,
'goldSQL': g_str,
'hardness': hardness,
'exact': exact_score,
'partial': partial_scores
})
index=index+1
for level in levels:
if scores[level]['count'] == 0:
continue
if etype in ["all", "exec"]:
scores[level]['exec'] /= scores[level]['count']
if etype in ["all", "match"]:
scores[level]['exact'] /= scores[level]['count']
for type_ in partial_types:
if scores[level]['partial'][type_]['acc_count'] == 0:
scores[level]['partial'][type_]['acc'] = 0
else:
scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
scores[level]['partial'][type_]['acc_count'] * 1.0
if scores[level]['partial'][type_]['rec_count'] == 0:
scores[level]['partial'][type_]['rec'] = 0
else:
scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
scores[level]['partial'][type_]['rec_count'] * 1.0
if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
scores[level]['partial'][type_]['f1'] = 1
else:
scores[level]['partial'][type_]['f1'] = \
2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
print_scores(scores, etype)
print(scores['all']['exec'])
current_directory = os.path.dirname(os.path.abspath(__file__))
file_name=current_directory+"/eval.json"
json_exist=os.path.exists(file_name)
if json_exist:
os.remove(file_name)
with open(file_name, 'w') as json_file:
json.dump(log_list, json_file, indent=4, ensure_ascii=False)
def eval_exec_match(db, p_str, g_str, pred, gold):
"""
return 1 if the values between prediction and gold are matching
in the corresponding index. Currently not support multiple col_unit(pairs).
"""
conn = sqlite3.connect(db)
cursor = conn.cursor()
try:
cursor.execute(p_str)
columns_tuple = cursor.description
p_fields = [field_tuple[0] for field_tuple in columns_tuple]
for index in range(0,len(p_fields)):
p_fields[index]=re.sub("t\d+.", "",p_fields[index].replace("`","").lower())
p_res = cursor.fetchall()
except:
return False
cursor.execute(g_str)
q_res = cursor.fetchall()
def res_map(res, p_fields):
rmap = {}
for i in range(0,len(p_fields)):
if p_fields[i] != "sys_imp_date":
value_list= [r[i] for r in res]
value_list.sort()
rmap[p_fields[i]] =value_list
return rmap
g_fields = parse_sql(g_str)
#print("p_res_map:{}".format(res_map(p_res, p_fields)))
#print("q_res_map:{}".format(res_map(q_res, g_fields)))
return res_map(p_res, p_fields) == res_map(q_res, g_fields)
def parse_sql(sql):
# 使用 sqlparse 库解析 SQL 查询语句
parsed = sqlparse.parse(sql)[0]
# 获取查询类型SELECT、INSERT、UPDATE 或 DELETE
query_type = parsed.get_type()
# 获取查询目标(表名、字段列表、值列表等)
if query_type == 'SELECT':
target = parse_select(parsed)
else:
target = None
return target
def parse_select(parsed):
# 获取字段列表
fields = []
for token in parsed.tokens:
#
if isinstance(token, sqlparse.sql.IdentifierList):
for identifier in token.get_identifiers():
fields.append(identifier.value.replace("`", "")
.replace("T1.", "").replace("T2.", "")
.replace("T3.", "").replace("T4.", "")
.replace("T5.", "").replace("T6.", ""))
if(len(fields)):
break
return fields
# Rebuild SQL functions for value evaluation
def rebuild_cond_unit_val(cond_unit):
if cond_unit is None or not DISABLE_VALUE:
return cond_unit
not_op, op_id, val_unit, val1, val2 = cond_unit
if type(val1) is not dict:
val1 = None
else:
val1 = rebuild_sql_val(val1)
if type(val2) is not dict:
val2 = None
else:
val2 = rebuild_sql_val(val2)
return not_op, op_id, val_unit, val1, val2
def rebuild_condition_val(condition):
if condition is None or not DISABLE_VALUE:
return condition
res = []
for idx, it in enumerate(condition):
if idx % 2 == 0:
res.append(rebuild_cond_unit_val(it))
else:
res.append(it)
return res
def rebuild_sql_val(sql):
if sql is None or not DISABLE_VALUE:
return sql
sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
sql['having'] = rebuild_condition_val(sql['having'])
sql['where'] = rebuild_condition_val(sql['where'])
sql['intersect'] = rebuild_sql_val(sql['intersect'])
sql['except'] = rebuild_sql_val(sql['except'])
sql['union'] = rebuild_sql_val(sql['union'])
return sql
# Rebuild SQL functions for foreign key evaluation
def build_valid_col_units(table_units, schema):
col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
prefixs = [col_id[:-2] for col_id in col_ids]
valid_col_units= []
for value in schema.idMap.values():
if '.' in value and value[:value.index('.')] in prefixs:
valid_col_units.append(value)
return valid_col_units
def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
if col_unit is None:
return col_unit
agg_id, col_id, distinct = col_unit
if col_id in kmap and col_id in valid_col_units:
col_id = kmap[col_id]
if DISABLE_DISTINCT:
distinct = None
return agg_id, col_id, distinct
def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
if val_unit is None:
return val_unit
unit_op, col_unit1, col_unit2 = val_unit
col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
return unit_op, col_unit1, col_unit2
def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
if table_unit is None:
return table_unit
table_type, col_unit_or_sql = table_unit
if isinstance(col_unit_or_sql, tuple):
col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
return table_type, col_unit_or_sql
def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
if cond_unit is None:
return cond_unit
not_op, op_id, val_unit, val1, val2 = cond_unit
val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
return not_op, op_id, val_unit, val1, val2
def rebuild_condition_col(valid_col_units, condition, kmap):
for idx in range(len(condition)):
if idx % 2 == 0:
condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
return condition
def rebuild_select_col(valid_col_units, sel, kmap):
if sel is None:
return sel
distinct, _list = sel
new_list = []
for it in _list:
agg_id, val_unit = it
new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
if DISABLE_DISTINCT:
distinct = None
return distinct, new_list
def rebuild_from_col(valid_col_units, from_, kmap):
if from_ is None:
return from_
from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap) for table_unit in from_['table_units']]
from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
return from_
def rebuild_group_by_col(valid_col_units, group_by, kmap):
if group_by is None:
return group_by
return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]
def rebuild_order_by_col(valid_col_units, order_by, kmap):
if order_by is None or len(order_by) == 0:
return order_by
direction, val_units = order_by
new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
return direction, new_val_units
def rebuild_sql_col(valid_col_units, sql, kmap):
if sql is None:
return sql
sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)
return sql
def build_foreign_key_map(entry):
cols_orig = entry["column_names_original"]
tables_orig = entry["table_names_original"]
# rebuild cols corresponding to idmap in Schema
cols = []
for col_orig in cols_orig:
if col_orig[0] >= 0:
t = tables_orig[col_orig[0]]
c = col_orig[1]
cols.append("__" + t.lower() + "." + c.lower() + "__")
else:
cols.append("__all__")
def keyset_in_list(k1, k2, k_list):
for k_set in k_list:
if k1 in k_set or k2 in k_set:
return k_set
new_k_set = set()
k_list.append(new_k_set)
return new_k_set
foreign_key_list = []
foreign_keys = entry["foreign_keys"]
for fkey in foreign_keys:
key1, key2 = fkey
key_set = keyset_in_list(key1, key2, foreign_key_list)
key_set.add(key1)
key_set.add(key2)
foreign_key_map = {}
for key_set in foreign_key_list:
sorted_list = sorted(list(key_set))
midx = sorted_list[0]
for idx in sorted_list:
foreign_key_map[cols[idx]] = cols[midx]
return foreign_key_map
def build_foreign_key_map_from_json(table):
with open(table) as f:
data = json.load(f)
tables = {}
for entry in data:
tables[entry['db_id']] = build_foreign_key_map(entry)
return tables
def get_evaluation_result():
current_directory = os.path.dirname(os.path.abspath(__file__))
config_file=current_directory+"/config/config.yaml"
with open(config_file, 'r') as file:
config = yaml.safe_load(file)
db_dir=current_directory+"/data"
db_path=current_directory+"/data/"
db_file=db_path+config["domain"]+".db"
pred = current_directory+"/data/"+"pred_example_dusql.txt"
gold = current_directory+"/data/"+"gold_example_dusql.txt"
table= current_directory+"/data/"+"tables_dusql.json"
query_path=current_directory+"/data/"+config["domain"]+".txt"
etype="exec"
kmaps = build_foreign_key_map_from_json(table)
evaluate(gold, pred, db_dir, etype, kmaps,query_path)
def remove_unused_file():
current_directory = os.path.dirname(os.path.abspath(__file__))
config_file=current_directory+"/config/config.yaml"
with open(config_file, 'r') as file:
config = yaml.safe_load(file)
db_path=current_directory+"/data/"
db_file=db_path+config["domain"]+".db"
pred_file = current_directory+"/data/"+"pred_example_dusql.txt"
db_exist=os.path.exists(db_file)
if db_exist:
os.remove(db_file)
print("db_file removed!")
pred_exist=os.path.exists(pred_file)
if pred_exist:
os.remove(pred_file)
print("pred_file removed!")
if __name__ == "__main__":
build_table()
get_pred_result()
get_evaluation_result()
remove_unused_file()