mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-10 19:18:23 +00:00
(improvement)add corrector additional information switch and evalution time cost (#743)
This commit is contained in:
@@ -65,12 +65,19 @@ def get_pred_result():
|
||||
questions=read_query(input_path)
|
||||
pred_sql_list=[]
|
||||
default_sql="select * from tablea "
|
||||
time_cost=[]
|
||||
for i in range(0,len(questions)):
|
||||
start_time = time.time()
|
||||
pred_sql=get_pred_sql(questions[i],url,agent_id,chat_id,authorization,default_sql)
|
||||
end_time = time.time()
|
||||
cost='%.3f'%(end_time-start_time)
|
||||
time_cost.append(cost)
|
||||
pred_sql_list.append(pred_sql)
|
||||
time.sleep(60)
|
||||
write_sql(pred_sql_path, pred_sql_list)
|
||||
|
||||
return [float(cost) for cost in time_cost]
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("pred")
|
||||
|
||||
|
||||
@@ -482,7 +482,7 @@ def print_scores(scores, etype):
|
||||
print("{:20} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f} {:<20.3f}".format(type_, *this_scores))
|
||||
|
||||
|
||||
def evaluate(gold, predict, db_dir, etype, kmaps,query_path):
|
||||
def evaluate(gold, predict, db_dir, etype, kmaps,query_path,time_cost):
|
||||
with open(gold) as f:
|
||||
glist = [l.strip().split('\t') for l in f.readlines() if len(l.strip()) > 0]
|
||||
|
||||
@@ -597,7 +597,11 @@ def evaluate(gold, predict, db_dir, etype, kmaps,query_path):
|
||||
scores[level]['partial'][type_]['f1'] = \
|
||||
2.0 * scores[level]['partial'][type_]['acc'] * scores[level]['partial'][type_]['rec'] / (
|
||||
scores[level]['partial'][type_]['rec'] + scores[level]['partial'][type_]['acc'])
|
||||
|
||||
cost_dic = {}
|
||||
cost_dic["max_time"] = max(time_cost)
|
||||
cost_dic["min_time"] = min(time_cost)
|
||||
cost_dic["avg_time"] = sum(time_cost)/len(time_cost)
|
||||
log_list.append(cost_dic)
|
||||
print_scores(scores, etype)
|
||||
print(scores['all']['exec'])
|
||||
current_directory = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -608,7 +612,6 @@ def evaluate(gold, predict, db_dir, etype, kmaps,query_path):
|
||||
with open(file_name, 'w') as json_file:
|
||||
json.dump(log_list, json_file, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
def eval_exec_match(db, p_str, g_str, pred, gold):
|
||||
"""
|
||||
return 1 if the values between prediction and gold are matching
|
||||
@@ -890,7 +893,7 @@ def build_foreign_key_map_from_json(table):
|
||||
tables[entry['db_id']] = build_foreign_key_map(entry)
|
||||
return tables
|
||||
|
||||
def get_evaluation_result():
|
||||
def get_evaluation_result(time_cost):
|
||||
current_directory = os.path.dirname(os.path.abspath(__file__))
|
||||
config_file=current_directory+"/config/config.yaml"
|
||||
with open(config_file, 'r') as file:
|
||||
@@ -905,7 +908,7 @@ def get_evaluation_result():
|
||||
etype="exec"
|
||||
kmaps = build_foreign_key_map_from_json(table)
|
||||
|
||||
evaluate(gold, pred, db_dir, etype, kmaps,query_path)
|
||||
evaluate(gold, pred, db_dir, etype, kmaps,query_path,time_cost)
|
||||
|
||||
def remove_unused_file():
|
||||
current_directory = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -927,8 +930,8 @@ def remove_unused_file():
|
||||
|
||||
if __name__ == "__main__":
|
||||
build_table()
|
||||
get_pred_result()
|
||||
get_evaluation_result()
|
||||
time_cost=get_pred_result()
|
||||
get_evaluation_result(time_cost)
|
||||
remove_unused_file()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user