mirror of
https://github.com/tencentmusic/supersonic.git
synced 2025-12-12 20:51:48 +00:00
(improvement) revise evaluation and fix null pointer (#715)
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
import time
|
||||
|
||||
import requests
|
||||
import logging
|
||||
import json
|
||||
import jwt
|
||||
import time
|
||||
import os
|
||||
import yaml
|
||||
from build_models import build,get_authorization
|
||||
|
||||
def read_query(input_path):
|
||||
result=[]
|
||||
@@ -24,13 +25,16 @@ def get_pred_sql(query,url,agentId,chatId,authorization,default_sql):
|
||||
header["Authorization"] =authorization
|
||||
try:
|
||||
result = requests.post(url=url, headers=header, json=data)
|
||||
print(result.json())
|
||||
print(result.json()["traceId"])
|
||||
if result.status_code == 200:
|
||||
data = result.json()["data"]
|
||||
selectedParses = data["selectedParses"]
|
||||
if selectedParses is not None and len(selectedParses) > 0:
|
||||
querySQL = selectedParses[0]["sqlInfo"]["querySQL"]
|
||||
querySQL=querySQL.replace("`dusql`.", "").replace("dusql", "").replace("\n", "")
|
||||
return querySQL+'\n'
|
||||
if querySQL is not None:
|
||||
querySQL=querySQL.replace("`dusql`.", "").replace("dusql", "").replace("\n", "")
|
||||
return querySQL+'\n'
|
||||
return default_sql+'\n'
|
||||
except Exception as e:
|
||||
print(url)
|
||||
@@ -38,24 +42,22 @@ def get_pred_sql(query,url,agentId,chatId,authorization,default_sql):
|
||||
print(e)
|
||||
logging.info(e)
|
||||
return default_sql+'\n'
|
||||
def get_authorization():
|
||||
exp = time.time() + 1000
|
||||
token= jwt.encode({"token_user_name": "admin","exp": exp}, "secret", algorithm="HS512")
|
||||
return "Bearer "+token
|
||||
|
||||
def get_pred_result():
|
||||
current_directory = os.path.dirname(os.path.abspath(__file__))
|
||||
config_file=current_directory+"/config/config.yaml"
|
||||
with open(config_file, 'r') as file:
|
||||
config = yaml.safe_load(file)
|
||||
input_path=current_directory+"/data/"+"internet.txt"
|
||||
pred_sql_path = current_directory+"/data/"+"pred_example_dusql.txt"
|
||||
input_path=current_directory+"/data/internet.txt"
|
||||
pred_sql_path = current_directory+"/data/pred_example_dusql.txt"
|
||||
pred_sql_exist=os.path.exists(pred_sql_path)
|
||||
if pred_sql_exist:
|
||||
os.remove(pred_sql_path)
|
||||
print("pred_sql_path removed!")
|
||||
agent_id=config["agent_id"]
|
||||
chat_id=config["chat_id"]
|
||||
dict_info=build()
|
||||
print(dict_info)
|
||||
agent_id=dict_info["agent_id"]
|
||||
chat_id=dict_info["chat_id"]
|
||||
url=config["url"]
|
||||
authorization=get_authorization()
|
||||
print(input_path)
|
||||
@@ -66,6 +68,7 @@ def get_pred_result():
|
||||
for i in range(0,len(questions)):
|
||||
pred_sql=get_pred_sql(questions[i],url,agent_id,chat_id,authorization,default_sql)
|
||||
pred_sql_list.append(pred_sql)
|
||||
time.sleep(30)
|
||||
write_sql(pred_sql_path, pred_sql_list)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user