mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-13 13:07:32 +00:00
(improvement) improve evaluation accuracy (#727)
This commit is contained in:
@@ -522,14 +522,19 @@ def evaluate(gold, predict, db_dir, etype, kmaps,query_path):
|
||||
p_sql = p_str
|
||||
|
||||
if etype in ["all", "exec"]:
|
||||
exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
|
||||
if not exec_score:
|
||||
result = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
|
||||
#exec_score = eval_exec_match(db, p_str, g_str, p_sql, g_sql)
|
||||
if not result["equal"]:
|
||||
element={}
|
||||
element["query"]=questions[index]
|
||||
element["gold_sql"]=g_str
|
||||
element["pred_sql"]=p_str
|
||||
if "p_res_map" in result:
|
||||
element["p_res_map"]=result["p_res_map"]
|
||||
if "q_res_map" in result:
|
||||
element["q_res_map"]=result["q_res_map"]
|
||||
log_list.append(element)
|
||||
if exec_score:
|
||||
if result["equal"]:
|
||||
scores[hardness]['exec'] += 1.0
|
||||
scores['all']['exec'] += 1.0
|
||||
|
||||
@@ -609,6 +614,7 @@ def eval_exec_match(db, p_str, g_str, pred, gold):
|
||||
return 1 if the values between prediction and gold are matching
|
||||
in the corresponding index. Currently not support multiple col_unit(pairs).
|
||||
"""
|
||||
result={}
|
||||
conn = sqlite3.connect(db)
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
@@ -618,8 +624,10 @@ def eval_exec_match(db, p_str, g_str, pred, gold):
|
||||
for index in range(0,len(p_fields)):
|
||||
p_fields[index]=re.sub("t\d+.", "",p_fields[index].replace("`","").lower())
|
||||
p_res = cursor.fetchall()
|
||||
except:
|
||||
return False
|
||||
except Exception as e:
|
||||
logging.info(e)
|
||||
result["equal"]=False
|
||||
return result
|
||||
|
||||
cursor.execute(g_str)
|
||||
q_res = cursor.fetchall()
|
||||
@@ -635,9 +643,15 @@ def eval_exec_match(db, p_str, g_str, pred, gold):
|
||||
|
||||
g_fields = parse_sql(g_str)
|
||||
|
||||
#print("p_res_map:{}".format(res_map(p_res, p_fields)))
|
||||
#print("q_res_map:{}".format(res_map(q_res, g_fields)))
|
||||
return res_map(p_res, p_fields) == res_map(q_res, g_fields)
|
||||
p_res_map=res_map(p_res, p_fields)
|
||||
q_res_map=res_map(q_res, g_fields)
|
||||
# print("p_res_map:{}".format(p_res_map))
|
||||
# print("q_res_map:{}".format(q_res_map))
|
||||
result["equal"]=(p_res_map==q_res_map)
|
||||
result["p_res_map"]=json.dumps(p_res_map, ensure_ascii=False)
|
||||
result["q_res_map"]=json.dumps(q_res_map, ensure_ascii=False)
|
||||
return result
|
||||
#return res_map(p_res, p_fields) == res_map(q_res, g_fields)
|
||||
|
||||
def parse_sql(sql):
|
||||
# 使用 sqlparse 库解析 SQL 查询语句
|
||||
|
||||
Reference in New Issue
Block a user