################################
# val: number(float)/string(str)/sql(dict)
# col_unit: (agg_id, col_id, isDistinct(bool))
# val_unit: (unit_op, col_unit1, col_unit2)
# table_unit: (table_type, col_unit/sql)
# cond_unit: (not_op, op_id, val_unit, val1, val2)
# condition: [cond_unit1, 'and'/'or', cond_unit2, ...]
# sql {
#   'select': (isDistinct(bool), [(agg_id, val_unit), (agg_id, val_unit), ...])
#   'from': {'table_units': [table_unit1, table_unit2, ...], 'conds': condition}
#   'where': condition
#   'groupBy': [col_unit1, col_unit2, ...]
#   'orderBy': ('asc'/'desc', [val_unit1, val_unit2, ...])
#   'having': condition
#   'limit': None/limit value
#   'intersect': None/sql
#   'except': None/sql
#   'union': None/sql
# }
################################
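

# Illustrative sketch of the format above (an assumption based on the Spider
# conventions this script follows, not output captured from this parser):
#     SELECT count(*) FROM singer WHERE age > 20
# roughly parses to
#     {'select': (False, [(3, (0, (0, '__all__', False), None))]),
#      'from': {'table_units': [('table_unit', '__singer__')], 'conds': []},
#      'where': [(False, 3, (0, (0, '__singer.age__', False), None), 20.0, None)],
#      'groupBy': [], 'having': [], 'orderBy': [], 'limit': None,
#      'intersect': None, 'except': None, 'union': None}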


from __future__ import print_function

import argparse
import json
import logging
import os
import re
import sqlite3
import sys
import traceback

import sqlparse
import yaml

from process_sql import tokenize, get_schema, get_tables_with_alias, Schema, get_sql
from build_pred_result import read_query, get_pred_result
from build_tables import build_table


# Flag to disable value evaluation
DISABLE_VALUE = True
# Flag to disable DISTINCT in SELECT evaluation
DISABLE_DISTINCT = True


CLAUSE_KEYWORDS = ('select', 'from', 'where', 'group', 'order', 'limit', 'intersect', 'union', 'except')
JOIN_KEYWORDS = ('join', 'on', 'as')

WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
UNIT_OPS = ('none', '-', '+', '*', '/')
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
TABLE_TYPE = {
    'sql': "sql",
    'table_unit': "table_unit",
}

COND_OPS = ('and', 'or')
SQL_OPS = ('intersect', 'union', 'except')
ORDER_OPS = ('desc', 'asc')


HARDNESS = {
    "component1": ('where', 'group', 'order', 'limit', 'join', 'or', 'like'),
    "component2": ('except', 'union', 'intersect')
}


def condition_has_or(conds):
    return 'or' in conds[1::2]


def condition_has_like(conds):
    return WHERE_OPS.index('like') in [cond_unit[1] for cond_unit in conds[::2]]


def condition_has_sql(conds):
    for cond_unit in conds[::2]:
        val1, val2 = cond_unit[3], cond_unit[4]
        if val1 is not None and type(val1) is dict:
            return True
        if val2 is not None and type(val2) is dict:
            return True
    return False


def val_has_op(val_unit):
    return val_unit[0] != UNIT_OPS.index('none')


def has_agg(unit):
    return unit[0] != AGG_OPS.index('none')


def accuracy(count, total):
    if count == total:
        return 1
    return 0


def recall(count, total):
    if count == total:
        return 1
    return 0


def F1(acc, rec):
    if (acc + rec) == 0:
        return 0
    return (2. * acc * rec) / (acc + rec)


def get_scores(count, pred_total, label_total):
    if pred_total != label_total:
        return 0, 0, 0
    elif count == pred_total:
        return 1, 1, 1
    return 0, 0, 0
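

# Note: get_scores is all-or-nothing per component. get_scores(2, 2, 2)
# returns (1, 1, 1), while get_scores(1, 2, 2) returns (0, 0, 0) even though
# half the units matched; fractional scores only appear once these 0/1 values
# are averaged across examples in evaluate().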


def eval_sel(pred, label):
    pred_sel = pred['select'][1]
    label_sel = label['select'][1]
    label_wo_agg = [unit[1] for unit in label_sel]
    pred_total = len(pred_sel)
    label_total = len(label_sel)
    cnt = 0
    cnt_wo_agg = 0

    for unit in pred_sel:
        if unit in label_sel:
            cnt += 1
            label_sel.remove(unit)
        if unit[1] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[1])

    return label_total, pred_total, cnt, cnt_wo_agg


def eval_where(pred, label):
    pred_conds = [unit for unit in pred['where'][::2]]
    label_conds = [unit for unit in label['where'][::2]]
    label_wo_agg = [unit[2] for unit in label_conds]
    pred_total = len(pred_conds)
    label_total = len(label_conds)
    cnt = 0
    cnt_wo_agg = 0

    for unit in pred_conds:
        if unit in label_conds:
            cnt += 1
            label_conds.remove(unit)
        if unit[2] in label_wo_agg:
            cnt_wo_agg += 1
            label_wo_agg.remove(unit[2])

    return label_total, pred_total, cnt, cnt_wo_agg


def eval_group(pred, label):
    pred_cols = [unit[1] for unit in pred['groupBy']]
    label_cols = [unit[1] for unit in label['groupBy']]
    pred_total = len(pred_cols)
    label_total = len(label_cols)
    cnt = 0
    # Compare columns with any "table." prefix stripped, so "T1.city" and
    # "city" count as the same grouping column.
    pred_cols = [col.split(".")[1] if "." in col else col for col in pred_cols]
    label_cols = [col.split(".")[1] if "." in col else col for col in label_cols]
    for col in pred_cols:
        if col in label_cols:
            cnt += 1
            label_cols.remove(col)
    return label_total, pred_total, cnt


def eval_having(pred, label):
    pred_total = label_total = cnt = 0
    if len(pred['groupBy']) > 0:
        pred_total = 1
    if len(label['groupBy']) > 0:
        label_total = 1

    pred_cols = [unit[1] for unit in pred['groupBy']]
    label_cols = [unit[1] for unit in label['groupBy']]
    if pred_total == label_total == 1 \
            and pred_cols == label_cols \
            and pred['having'] == label['having']:
        cnt = 1

    return label_total, pred_total, cnt


def eval_order(pred, label):
    pred_total = label_total = cnt = 0
    if len(pred['orderBy']) > 0:
        pred_total = 1
    if len(label['orderBy']) > 0:
        label_total = 1
    # Count a match only when the ORDER BY clauses agree and both queries
    # either have a LIMIT or both lack one.
    if len(label['orderBy']) > 0 and pred['orderBy'] == label['orderBy'] and \
            ((pred['limit'] is None and label['limit'] is None) or
             (pred['limit'] is not None and label['limit'] is not None)):
        cnt = 1
    return label_total, pred_total, cnt


def eval_and_or(pred, label):
    pred_ao = set(pred['where'][1::2])
    label_ao = set(label['where'][1::2])

    if pred_ao == label_ao:
        return 1, 1, 1
    return len(pred_ao), len(label_ao), 0


def get_nestedSQL(sql):
    nested = []
    for cond_unit in sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]:
        if type(cond_unit[3]) is dict:
            nested.append(cond_unit[3])
        if type(cond_unit[4]) is dict:
            nested.append(cond_unit[4])
    if sql['intersect'] is not None:
        nested.append(sql['intersect'])
    if sql['except'] is not None:
        nested.append(sql['except'])
    if sql['union'] is not None:
        nested.append(sql['union'])
    return nested


def eval_nested(pred, label):
    label_total = 0
    pred_total = 0
    cnt = 0
    if pred is not None:
        pred_total += 1
    if label is not None:
        label_total += 1
    if pred is not None and label is not None:
        cnt += Evaluator().eval_exact_match(pred, label)
    return label_total, pred_total, cnt


def eval_IUEN(pred, label):
    lt1, pt1, cnt1 = eval_nested(pred['intersect'], label['intersect'])
    lt2, pt2, cnt2 = eval_nested(pred['except'], label['except'])
    lt3, pt3, cnt3 = eval_nested(pred['union'], label['union'])
    label_total = lt1 + lt2 + lt3
    pred_total = pt1 + pt2 + pt3
    cnt = cnt1 + cnt2 + cnt3
    return label_total, pred_total, cnt


def get_keywords(sql):
    res = set()
    if len(sql['where']) > 0:
        res.add('where')
    if len(sql['groupBy']) > 0:
        res.add('group')
    if len(sql['having']) > 0:
        res.add('having')
    if len(sql['orderBy']) > 0:
        res.add(sql['orderBy'][0])
        res.add('order')
    if sql['limit'] is not None:
        res.add('limit')
    if sql['except'] is not None:
        res.add('except')
    if sql['union'] is not None:
        res.add('union')
    if sql['intersect'] is not None:
        res.add('intersect')

    # or keyword
    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
    if len([token for token in ao if token == 'or']) > 0:
        res.add('or')

    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
    # not keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[0]]) > 0:
        res.add('not')

    # in keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('in')]) > 0:
        res.add('in')

    # like keyword
    if len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')]) > 0:
        res.add('like')

    return res
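

# Example (assuming the parser records the default 'asc' direction):
#     SELECT name FROM t WHERE a = 1 OR b LIKE '%x%' ORDER BY name LIMIT 3
# yields {'where', 'or', 'like', 'order', 'asc', 'limit'}.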


def eval_keywords(pred, label):
    pred_keywords = get_keywords(pred)
    label_keywords = get_keywords(label)
    pred_total = len(pred_keywords)
    label_total = len(label_keywords)
    cnt = 0

    for k in pred_keywords:
        if k in label_keywords:
            cnt += 1
    return label_total, pred_total, cnt


def count_agg(units):
    return len([unit for unit in units if has_agg(unit)])


def count_component1(sql):
    count = 0
    if len(sql['where']) > 0:
        count += 1
    if len(sql['groupBy']) > 0:
        count += 1
    if len(sql['orderBy']) > 0:
        count += 1
    if sql['limit'] is not None:
        count += 1
    if len(sql['from']['table_units']) > 0:  # JOIN
        count += len(sql['from']['table_units']) - 1

    ao = sql['from']['conds'][1::2] + sql['where'][1::2] + sql['having'][1::2]
    count += len([token for token in ao if token == 'or'])
    cond_units = sql['from']['conds'][::2] + sql['where'][::2] + sql['having'][::2]
    count += len([cond_unit for cond_unit in cond_units if cond_unit[1] == WHERE_OPS.index('like')])

    return count


def count_component2(sql):
    nested = get_nestedSQL(sql)
    return len(nested)


def count_others(sql):
    count = 0
    # number of aggregations
    agg_count = count_agg(sql['select'][1])
    agg_count += count_agg(sql['where'][::2])
    agg_count += count_agg(sql['groupBy'])
    if len(sql['orderBy']) > 0:
        agg_count += count_agg([unit[1] for unit in sql['orderBy'][1] if unit[1]] +
                               [unit[2] for unit in sql['orderBy'][1] if unit[2]])
    agg_count += count_agg(sql['having'])
    if agg_count > 1:
        count += 1

    # number of select columns
    if len(sql['select'][1]) > 1:
        count += 1

    # number of where conditions
    if len(sql['where']) > 1:
        count += 1

    # number of group by clauses
    if len(sql['groupBy']) > 1:
        count += 1

    return count
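

# Worked example: for "SELECT max(age), min(age) FROM singer WHERE country = 'US'"
# count_component1 = 1 (the WHERE clause), count_component2 = 0, and
# count_others = 2 (more than one aggregation, more than one SELECT column),
# so Evaluator.eval_hardness below classifies the query as "medium".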


class Evaluator:
    """A simple evaluator"""

    def __init__(self):
        self.partial_scores = None

    def eval_hardness(self, sql):
        count_comp1_ = count_component1(sql)
        count_comp2_ = count_component2(sql)
        count_others_ = count_others(sql)

        if count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ == 0:
            return "easy"
        elif (count_others_ <= 2 and count_comp1_ <= 1 and count_comp2_ == 0) or \
                (count_comp1_ <= 2 and count_others_ < 2 and count_comp2_ == 0):
            return "medium"
        elif (count_others_ > 2 and count_comp1_ <= 2 and count_comp2_ == 0) or \
                (2 < count_comp1_ <= 3 and count_others_ <= 2 and count_comp2_ == 0) or \
                (count_comp1_ <= 1 and count_others_ == 0 and count_comp2_ <= 1):
            return "hard"
        else:
            return "extra"

    def eval_exact_match(self, pred, label):
        partial_scores = self.eval_partial_match(pred, label)
        self.partial_scores = partial_scores

        for _, score in partial_scores.items():
            if score['f1'] != 1:
                return 0
        if len(label['from']['table_units']) > 0:
            label_tables = sorted(label['from']['table_units'])
            pred_tables = sorted(pred['from']['table_units'])
            return label_tables == pred_tables
        return 1

    def eval_partial_match(self, pred, label):
        res = {}

        label_total, pred_total, cnt, cnt_wo_agg = eval_sel(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['select'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}
        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
        res['select(no AGG)'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        label_total, pred_total, cnt, cnt_wo_agg = eval_where(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['where'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}
        acc, rec, f1 = get_scores(cnt_wo_agg, pred_total, label_total)
        res['where(no OP)'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        label_total, pred_total, cnt = eval_group(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['group(no Having)'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        label_total, pred_total, cnt = eval_having(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['group'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        label_total, pred_total, cnt = eval_order(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['order'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        label_total, pred_total, cnt = eval_and_or(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['and/or'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        label_total, pred_total, cnt = eval_IUEN(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['IUEN'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        label_total, pred_total, cnt = eval_keywords(pred, label)
        acc, rec, f1 = get_scores(cnt, pred_total, label_total)
        res['keywords'] = {'acc': acc, 'rec': rec, 'f1': f1, 'label_total': label_total, 'pred_total': pred_total}

        return res
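

# Minimal usage sketch (hypothetical file path): parse a gold and a predicted
# query against the same schema, then compare them.
#
#     schema = Schema(get_schema("data/demo.db"))
#     g_sql = get_sql(schema, "SELECT name FROM singer")
#     p_sql = get_sql(schema, "SELECT name FROM singer")
#     Evaluator().eval_exact_match(p_sql, g_sql)   # truthy on an exact match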


def isValidSQL(sql, db):
    conn = sqlite3.connect(db)
    cursor = conn.cursor()
    try:
        cursor.execute(sql)
    except Exception:
        return False
    finally:
        conn.close()
    return True


def print_scores(scores, etype):
    levels = ['easy', 'medium', 'hard', 'extra', 'all']
    partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
                     'group', 'order', 'and/or', 'IUEN', 'keywords']

    print("{:20} {:20} {:20} {:20} {:20} {:20}".format("", *levels))
    counts = [scores[level]['count'] for level in levels]
    print("{:20} {:<20d} {:<20d} {:<20d} {:<20d} {:<20d}".format("count", *counts))

    if etype in ["all", "exec"]:
        print('===================== EXECUTION ACCURACY =====================')
        this_scores = [scores[level]['exec'] for level in levels]
        print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("execution", *this_scores))

    if etype in ["all", "match"]:
        print('\n====================== EXACT MATCHING ACCURACY =====================')
        exact_scores = [scores[level]['exact'] for level in levels]
        print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format("exact match", *exact_scores))
        print('\n--------------------- PARTIAL MATCHING ACCURACY ---------------------')
        for type_ in partial_types:
            this_scores = [scores[level]['partial'][type_]['acc'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))

        print('---------------------- PARTIAL MATCHING RECALL ----------------------')
        for type_ in partial_types:
            this_scores = [scores[level]['partial'][type_]['rec'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))

        print('---------------------- PARTIAL MATCHING F1 --------------------------')
        for type_ in partial_types:
            this_scores = [scores[level]['partial'][type_]['f1'] for level in levels]
            print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))


def evaluate(gold, predict, db_dir, etype, kmaps, query_path):
    with open(gold) as f:
        glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]

    with open(predict) as f:
        plist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
    # plist = [("select max(Share),min(Share) from performance where Type != 'terminal'", "orchestra")]
    # glist = [("SELECT max(SHARE) , min(SHARE) FROM performance WHERE TYPE != 'Live final'", "orchestra")]
    evaluator = Evaluator()
    levels = ['easy', 'medium', 'hard', 'extra', 'all']
    partial_types = ['select', 'select(no AGG)', 'where', 'where(no OP)', 'group(no Having)',
                     'group', 'order', 'and/or', 'IUEN', 'keywords']
    entries = []
    scores = {}
    log_list = []
    for level in levels:
        scores[level] = {'count': 0, 'partial': {}, 'exact': 0.}
        scores[level]['exec'] = 0
        for type_ in partial_types:
            scores[level]['partial'][type_] = {'acc': 0., 'rec': 0., 'f1': 0., 'acc_count': 0, 'rec_count': 0}

    eval_err_num = 0
    questions = read_query(query_path)
    index = 0
    for p, g in zip(plist, glist):
        p_str = p[0]
        g_str, db = g
        db_name = db
        db = os.path.join(db_dir, db + ".db")
        print(db)
        schema = Schema(get_schema(db))
        g_sql = get_sql(schema, g_str)
        hardness = evaluator.eval_hardness(g_sql)
        scores[hardness]['count'] += 1
        scores['all']['count'] += 1

        # Parse the predicted SQL for exact/partial matching. Unparsable
        # predictions fall back to an empty query skeleton (the same fallback
        # the reference Spider evaluator uses) instead of crashing.
        try:
            p_sql = get_sql(schema, p_str)
        except Exception:
            eval_err_num += 1
            p_sql = {
                "except": None,
                "from": {"conds": [], "table_units": []},
                "groupBy": [],
                "having": [],
                "intersect": None,
                "limit": None,
                "orderBy": [],
                "select": [False, []],
                "union": None,
                "where": []
            }

        if etype in ["all", "exec"]:
            exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
            if not exec_score:
                log_list.append({
                    "query": questions[index],
                    "gold_sql": g_str,
                    "pred_sql": p_str
                })
            if exec_score:
                scores[hardness]['exec'] += 1.0
                scores['all']['exec'] += 1.0

        if etype in ["all", "match"]:
            exact_score = evaluator.eval_exact_match(p_sql, g_sql)
            partial_scores = evaluator.partial_scores
            if exact_score == 0:
                print("{} pred: {}".format(hardness, p_str))
                print("{} gold: {}".format(hardness, g_str))
                print("")
            scores[hardness]['exact'] += exact_score
            scores['all']['exact'] += exact_score
            for type_ in partial_types:
                if partial_scores[type_]['pred_total'] > 0:
                    scores[hardness]['partial'][type_]['acc'] += partial_scores[type_]['acc']
                    scores[hardness]['partial'][type_]['acc_count'] += 1
                if partial_scores[type_]['label_total'] > 0:
                    scores[hardness]['partial'][type_]['rec'] += partial_scores[type_]['rec']
                    scores[hardness]['partial'][type_]['rec_count'] += 1
                scores[hardness]['partial'][type_]['f1'] += partial_scores[type_]['f1']
                if partial_scores[type_]['pred_total'] > 0:
                    scores['all']['partial'][type_]['acc'] += partial_scores[type_]['acc']
                    scores['all']['partial'][type_]['acc_count'] += 1
                if partial_scores[type_]['label_total'] > 0:
                    scores['all']['partial'][type_]['rec'] += partial_scores[type_]['rec']
                    scores['all']['partial'][type_]['rec_count'] += 1
                scores['all']['partial'][type_]['f1'] += partial_scores[type_]['f1']

            entries.append({
                'predictSQL': p_str,
                'goldSQL': g_str,
                'hardness': hardness,
                'exact': exact_score,
                'partial': partial_scores
            })
        index += 1

    for level in levels:
        if scores[level]['count'] == 0:
            continue
        if etype in ["all", "exec"]:
            scores[level]['exec'] /= scores[level]['count']

        if etype in ["all", "match"]:
            scores[level]['exact'] /= scores[level]['count']
            for type_ in partial_types:
                if scores[level]['partial'][type_]['acc_count'] == 0:
                    scores[level]['partial'][type_]['acc'] = 0
                else:
                    scores[level]['partial'][type_]['acc'] = scores[level]['partial'][type_]['acc'] / \
                        scores[level]['partial'][type_]['acc_count'] * 1.0
                if scores[level]['partial'][type_]['rec_count'] == 0:
                    scores[level]['partial'][type_]['rec'] = 0
                else:
                    scores[level]['partial'][type_]['rec'] = scores[level]['partial'][type_]['rec'] / \
                        scores[level]['partial'][type_]['rec_count'] * 1.0
                # When both averages are zero, f1 is pinned to 1, mirroring
                # the reference Spider evaluator.
                if scores[level]['partial'][type_]['acc'] == 0 and scores[level]['partial'][type_]['rec'] == 0:
                    scores[level]['partial'][type_]['f1'] = 1
                else:
                    scores[level]['partial'][type_]['f1'] = \
                        2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
                            scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])

    print_scores(scores, etype)
    print(scores['all']['exec'])
    # Dump the failed executions to eval.json for error analysis.
    current_directory = os.path.dirname(os.path.abspath(__file__))
    file_name = current_directory + "/eval.json"
    if os.path.exists(file_name):
        os.remove(file_name)
    with open(file_name, 'w') as json_file:
        json.dump(log_list, json_file, indent=4, ensure_ascii=False)


def eval_exec_match(db, p_str, g_str, pred, gold):
    """
    Return True if the predicted and gold queries produce matching result
    values, compared column by column. Multiple col_unit pairs are currently
    not supported.
    """
    conn = sqlite3.connect(db)
    cursor = conn.cursor()
    try:
        cursor.execute(p_str)
        columns_tuple = cursor.description
        p_fields = [field_tuple[0] for field_tuple in columns_tuple]
        # Normalize predicted column labels: drop backticks, lower-case, and
        # strip table-alias prefixes like "t1.".
        for i in range(len(p_fields)):
            p_fields[i] = re.sub(r"t\d+\.", "", p_fields[i].replace("`", "").lower())
        p_res = cursor.fetchall()
    except Exception:
        return False

    cursor.execute(g_str)
    q_res = cursor.fetchall()

    def res_map(res, fields):
        rmap = {}
        for i in range(len(fields)):
            if fields[i] != "sys_imp_date":
                value_list = [r[i] for r in res]
                value_list.sort()
                rmap[fields[i]] = value_list
        return rmap

    g_fields = parse_sql(g_str)

    # print("p_res_map:{}".format(res_map(p_res, p_fields)))
    # print("q_res_map:{}".format(res_map(q_res, g_fields)))
    return res_map(p_res, p_fields) == res_map(q_res, g_fields)
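

# res_map builds {column_name: sorted column values}, so execution match is
# order-insensitive within each column but ignores row alignment across
# columns: result sets [(1, 'a'), (2, 'b')] and [(1, 'b'), (2, 'a')] compare
# equal. Columns named "sys_imp_date" are excluded from the comparison.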


def parse_sql(sql):
    # Parse the SQL query with the sqlparse library.
    parsed = sqlparse.parse(sql)[0]

    # Get the statement type (SELECT, INSERT, UPDATE, or DELETE).
    query_type = parsed.get_type()

    # Only SELECT statements carry a field list we can extract.
    if query_type == 'SELECT':
        target = parse_select(parsed)
    else:
        target = None

    return target


def parse_select(parsed):
    # Collect the selected fields, stripping backticks and alias prefixes.
    fields = []
    for token in parsed.tokens:
        if isinstance(token, sqlparse.sql.IdentifierList):
            for identifier in token.get_identifiers():
                fields.append(identifier.value.replace("`", "")
                              .replace("T1.", "").replace("T2.", "")
                              .replace("T3.", "").replace("T4.", "")
                              .replace("T5.", "").replace("T6.", ""))
        if fields:
            break
    return fields
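

# Caveat: parse_select stops at the first IdentifierList token, so it returns
# only the top-level SELECT fields, and it strips alias prefixes T1. through
# T6. only; queries with more aliases would need the replacement list extended.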


# Rebuild SQL functions for value evaluation
def rebuild_cond_unit_val(cond_unit):
    if cond_unit is None or not DISABLE_VALUE:
        return cond_unit

    not_op, op_id, val_unit, val1, val2 = cond_unit
    if type(val1) is not dict:
        val1 = None
    else:
        val1 = rebuild_sql_val(val1)
    if type(val2) is not dict:
        val2 = None
    else:
        val2 = rebuild_sql_val(val2)
    return not_op, op_id, val_unit, val1, val2


def rebuild_condition_val(condition):
    if condition is None or not DISABLE_VALUE:
        return condition

    res = []
    for idx, it in enumerate(condition):
        if idx % 2 == 0:
            res.append(rebuild_cond_unit_val(it))
        else:
            res.append(it)
    return res


def rebuild_sql_val(sql):
    if sql is None or not DISABLE_VALUE:
        return sql

    sql['from']['conds'] = rebuild_condition_val(sql['from']['conds'])
    sql['having'] = rebuild_condition_val(sql['having'])
    sql['where'] = rebuild_condition_val(sql['where'])
    sql['intersect'] = rebuild_sql_val(sql['intersect'])
    sql['except'] = rebuild_sql_val(sql['except'])
    sql['union'] = rebuild_sql_val(sql['union'])

    return sql


# Rebuild SQL functions for foreign key evaluation
def build_valid_col_units(table_units, schema):
    col_ids = [table_unit[1] for table_unit in table_units if table_unit[0] == TABLE_TYPE['table_unit']]
    prefixes = [col_id[:-2] for col_id in col_ids]
    valid_col_units = []
    for value in schema.idMap.values():
        if '.' in value and value[:value.index('.')] in prefixes:
            valid_col_units.append(value)
    return valid_col_units


def rebuild_col_unit_col(valid_col_units, col_unit, kmap):
    if col_unit is None:
        return col_unit

    agg_id, col_id, distinct = col_unit
    if col_id in kmap and col_id in valid_col_units:
        col_id = kmap[col_id]
    if DISABLE_DISTINCT:
        distinct = None
    return agg_id, col_id, distinct


def rebuild_val_unit_col(valid_col_units, val_unit, kmap):
    if val_unit is None:
        return val_unit

    unit_op, col_unit1, col_unit2 = val_unit
    col_unit1 = rebuild_col_unit_col(valid_col_units, col_unit1, kmap)
    col_unit2 = rebuild_col_unit_col(valid_col_units, col_unit2, kmap)
    return unit_op, col_unit1, col_unit2


def rebuild_table_unit_col(valid_col_units, table_unit, kmap):
    if table_unit is None:
        return table_unit

    table_type, col_unit_or_sql = table_unit
    if isinstance(col_unit_or_sql, tuple):
        col_unit_or_sql = rebuild_col_unit_col(valid_col_units, col_unit_or_sql, kmap)
    return table_type, col_unit_or_sql


def rebuild_cond_unit_col(valid_col_units, cond_unit, kmap):
    if cond_unit is None:
        return cond_unit

    not_op, op_id, val_unit, val1, val2 = cond_unit
    val_unit = rebuild_val_unit_col(valid_col_units, val_unit, kmap)
    return not_op, op_id, val_unit, val1, val2


def rebuild_condition_col(valid_col_units, condition, kmap):
    for idx in range(len(condition)):
        if idx % 2 == 0:
            condition[idx] = rebuild_cond_unit_col(valid_col_units, condition[idx], kmap)
    return condition


def rebuild_select_col(valid_col_units, sel, kmap):
    if sel is None:
        return sel
    distinct, _list = sel
    new_list = []
    for it in _list:
        agg_id, val_unit = it
        new_list.append((agg_id, rebuild_val_unit_col(valid_col_units, val_unit, kmap)))
    if DISABLE_DISTINCT:
        distinct = None
    return distinct, new_list


def rebuild_from_col(valid_col_units, from_, kmap):
    if from_ is None:
        return from_

    from_['table_units'] = [rebuild_table_unit_col(valid_col_units, table_unit, kmap)
                            for table_unit in from_['table_units']]
    from_['conds'] = rebuild_condition_col(valid_col_units, from_['conds'], kmap)
    return from_


def rebuild_group_by_col(valid_col_units, group_by, kmap):
    if group_by is None:
        return group_by

    return [rebuild_col_unit_col(valid_col_units, col_unit, kmap) for col_unit in group_by]


def rebuild_order_by_col(valid_col_units, order_by, kmap):
    if order_by is None or len(order_by) == 0:
        return order_by

    direction, val_units = order_by
    new_val_units = [rebuild_val_unit_col(valid_col_units, val_unit, kmap) for val_unit in val_units]
    return direction, new_val_units


def rebuild_sql_col(valid_col_units, sql, kmap):
    if sql is None:
        return sql

    sql['select'] = rebuild_select_col(valid_col_units, sql['select'], kmap)
    sql['from'] = rebuild_from_col(valid_col_units, sql['from'], kmap)
    sql['where'] = rebuild_condition_col(valid_col_units, sql['where'], kmap)
    sql['groupBy'] = rebuild_group_by_col(valid_col_units, sql['groupBy'], kmap)
    sql['orderBy'] = rebuild_order_by_col(valid_col_units, sql['orderBy'], kmap)
    sql['having'] = rebuild_condition_col(valid_col_units, sql['having'], kmap)
    sql['intersect'] = rebuild_sql_col(valid_col_units, sql['intersect'], kmap)
    sql['except'] = rebuild_sql_col(valid_col_units, sql['except'], kmap)
    sql['union'] = rebuild_sql_col(valid_col_units, sql['union'], kmap)

    return sql
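

# Sketch of the normalization above (made-up identifiers, assuming the column
# appears in valid_col_units and DISABLE_DISTINCT is True):
#     kmap = {'__orders.user_id__': '__users.id__'}
#     rebuild_col_unit_col(valid_cols, (0, '__orders.user_id__', False), kmap)
#     # -> (0, '__users.id__', None)
# so foreign-key-equivalent columns compare equal during matching.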


def build_foreign_key_map(entry):
    cols_orig = entry["column_names_original"]
    tables_orig = entry["table_names_original"]

    # rebuild cols corresponding to idMap in Schema
    cols = []
    for col_orig in cols_orig:
        if col_orig[0] >= 0:
            t = tables_orig[col_orig[0]]
            c = col_orig[1]
            cols.append("__" + t.lower() + "." + c.lower() + "__")
        else:
            cols.append("__all__")

    def keyset_in_list(k1, k2, k_list):
        for k_set in k_list:
            if k1 in k_set or k2 in k_set:
                return k_set
        new_k_set = set()
        k_list.append(new_k_set)
        return new_k_set

    foreign_key_list = []
    foreign_keys = entry["foreign_keys"]
    for fkey in foreign_keys:
        key1, key2 = fkey
        key_set = keyset_in_list(key1, key2, foreign_key_list)
        key_set.add(key1)
        key_set.add(key2)

    foreign_key_map = {}
    for key_set in foreign_key_list:
        sorted_list = sorted(list(key_set))
        midx = sorted_list[0]
        for idx in sorted_list:
            foreign_key_map[cols[idx]] = cols[midx]

    return foreign_key_map
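

# Sketch with a made-up schema: if column 2 is users.id and column 5 is
# orders.user_id, and entry["foreign_keys"] contains [5, 2], then both
# "__users.id__" and "__orders.user_id__" map to "__users.id__" (the member
# of the key set with the smallest column index), unifying them in matching.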


def build_foreign_key_map_from_json(table):
    with open(table) as f:
        data = json.load(f)
    tables = {}
    for entry in data:
        tables[entry['db_id']] = build_foreign_key_map(entry)
    return tables


def get_evaluation_result():
    current_directory = os.path.dirname(os.path.abspath(__file__))
    config_file = current_directory + "/config/config.yaml"
    with open(config_file, 'r') as file:
        config = yaml.safe_load(file)
    db_dir = current_directory + "/data"
    pred = current_directory + "/data/" + "pred_example_dusql.txt"
    gold = current_directory + "/data/" + "gold_example_dusql.txt"
    table = current_directory + "/data/" + "tables_dusql.json"
    query_path = current_directory + "/data/" + config["domain"] + ".txt"
    etype = "exec"
    kmaps = build_foreign_key_map_from_json(table)

    evaluate(gold, pred, db_dir, etype, kmaps, query_path)


def remove_unused_file():
    current_directory = os.path.dirname(os.path.abspath(__file__))
    config_file = current_directory + "/config/config.yaml"
    with open(config_file, 'r') as file:
        config = yaml.safe_load(file)
    db_path = current_directory + "/data/"
    db_file = db_path + config["domain"] + ".db"
    pred_file = current_directory + "/data/" + "pred_example_dusql.txt"

    if os.path.exists(db_file):
        os.remove(db_file)
        print("db_file removed!")
    if os.path.exists(pred_file):
        os.remove(pred_file)
        print("pred_file removed!")
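

# End-to-end pipeline (see __main__ below): build_table() materializes the
# benchmark SQLite database, get_pred_result() (from build_pred_result)
# presumably produces pred_example_dusql.txt, get_evaluation_result() scores
# predictions against gold_example_dusql.txt by execution accuracy, and
# remove_unused_file() cleans up the generated database and prediction file.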


if __name__ == "__main__":
    build_table()
    get_pred_result()
    get_evaluation_result()
    remove_unused_file()