(improvement) improve evalution accuracy (#727)

This commit is contained in:
mainmain
2024-02-19 17:31:38 +08:00
committed by GitHub
parent fdb69547e6
commit 699a33b1c1
13 changed files with 112 additions and 53 deletions

View File

@@ -522,14 +522,19 @@ def evaluate(gold, predict, db_dir, etype, kmaps,query_path):
p_sql = p_str
if etype in ["all", "exec"]:
exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
if not exec_score:
result = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
#exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
if not result["equal"]:
element={}
element["query"]=questions[index]
element["gold_sql"]=g_str
element["pred_sql"]=p_str
if "p_res_map" in result:
element["p_res_map"]=result["p_res_map"]
if "q_res_map" in result:
element["q_res_map"]=result["q_res_map"]
log_list.append(element)
if exec_score:
if result["equal"]:
scores[hardness]['exec'] += 1.0
scores['all']['exec'] += 1.0
@@ -609,6 +614,7 @@ def eval_exec_match(db, p_str, g_str, pred, gold):
return 1 if the values between prediction and gold are matching
in the corresponding index. Currently not support multiple col_unit(pairs).
"""
result={}
conn = sqlite3.connect(db)
cursor = conn.cursor()
try:
@@ -618,8 +624,10 @@ def eval_exec_match(db, p_str, g_str, pred, gold):
for index in range(0,len(p_fields)):
p_fields[index]=re.sub("t\d+.", "",p_fields[index].replace("`","").lower())
p_res = cursor.fetchall()
except:
return False
except Exception as e:
logging.info(e)
result["equal"]=False
return result
cursor.execute(g_str)
q_res = cursor.fetchall()
@@ -635,9 +643,15 @@ def eval_exec_match(db, p_str, g_str, pred, gold):
g_fields = parse_sql(g_str)
#print("p_res_map:{}".format(res_map(p_res, p_fields)))
#print("q_res_map:{}".format(res_map(q_res, g_fields)))
return res_map(p_res, p_fields) == res_map(q_res, g_fields)
p_res_map=res_map(p_res, p_fields)
q_res_map=res_map(q_res, g_fields)
# print("p_res_map:{}".format(p_res_map))
# print("q_res_map:{}".format(q_res_map))
result["equal"]=(p_res_map==q_res_map)
result["p_res_map"]=json.dumps(p_res_map, ensure_ascii=False)
result["q_res_map"]=json.dumps(q_res_map, ensure_ascii=False)
return result
#return res_map(p_res, p_fields) == res_map(q_res, g_fields)
def parse_sql(sql):
# 使用 sqlparse 库解析 SQL 查询语句