Files
supersonic/evaluation/build_pred_result.py
mainmain 7949efea75 (improvement)(Headless) corrector supports subselect sql (#1006)
* [improvement] corrector support subselect sql

---------

Co-authored-by: zuopengge
2024-05-17 10:01:49 +08:00

86 lines
2.7 KiB
Python

import time
import requests
import logging
import json
import os
import yaml
from build_models import build,get_authorization
def read_query(input_path):
result=[]
with open(input_path, "r") as f:
for line in f.readlines():
line = line.strip('\n')
result.append(line)
return result
def write_sql(output_path,result):
file = open(output_path, mode='a')
file.writelines(result)
file.close()
def get_pred_sql(query,url,agentId,chatId,authorization,default_sql):
url=url+"/api/chat/query/parse"
data = {"agentId": agentId, "chatId":chatId,"queryText":query}
header = {}
header["Authorization"] =authorization
try:
result = requests.post(url=url, headers=header, json=data)
#print(result.json())
print(result.json()["traceId"])
if result.status_code == 200:
data = result.json()["data"]
selectedParses = data["selectedParses"]
if selectedParses is not None and len(selectedParses) > 0:
querySQL = selectedParses[0]["sqlInfo"]["querySQL"]
if querySQL is not None:
querySQL=querySQL.replace("`dusql`.", "").replace("dusql", "").replace("\n", "")
return querySQL+'\n'
return default_sql+'\n'
except Exception as e:
print(url)
print(result.json())
print(e)
logging.info(e)
return default_sql+'\n'
def get_pred_result():
current_directory = os.path.dirname(os.path.abspath(__file__))
config_file=current_directory+"/config/config.yaml"
with open(config_file, 'r') as file:
config = yaml.safe_load(file)
input_path=current_directory+"/data/internet.txt"
pred_sql_path = current_directory+"/data/pred_example_dusql.txt"
pred_sql_exist=os.path.exists(pred_sql_path)
if pred_sql_exist:
os.remove(pred_sql_path)
print("pred_sql_path removed!")
dict_info=build()
print(dict_info)
agent_id=dict_info["agent_id"]
chat_id=dict_info["chat_id"]
url=config["url"]
authorization=get_authorization()
print(input_path)
print(pred_sql_path)
questions=read_query(input_path)
pred_sql_list=[]
default_sql="select * from tablea "
time_cost=[]
time.sleep(60)
for i in range(0,len(questions)):
start_time = time.time()
pred_sql=get_pred_sql(questions[i],url,agent_id,chat_id,authorization,default_sql)
end_time = time.time()
cost='%.3f'%(end_time-start_time)
time_cost.append(cost)
pred_sql_list.append(pred_sql)
time.sleep(3)
write_sql(pred_sql_path, pred_sql_list)
return [float(cost) for cost in time_cost]
if __name__ == "__main__":
print("pred")