Mirror of https://github.com/tencentmusic/supersonic.git, synced 2025-12-10 11:07:06 +00:00
(improve)(benchmark) improve benchmark, add analysis of parsing results (#2215)
.gitignore (vendored, 1 addition)
@@ -21,3 +21,4 @@ __pycache__/
 /dict
 assembly/build/*-SNAPSHOT
 **/node_modules/
+benchmark/res/

@@ -15,6 +15,68 @@ import requests
 import time
 import jwt
 import traceback
+import os
+from datetime import datetime
+
+
+class DataFrameAppender:
+    def __init__(self, file_name="output"):
+        # Column headers: question, parse status, parse time, execute status, execute time, total time
+        columns = ['问题', '解析状态', '解析耗时', '执行状态', '执行耗时', '总耗时']
+        # Create a DataFrame that holds only the header
+        self.df = pd.DataFrame(columns=columns)
+        self.file_name = file_name
+
+    def append_data(self, new_data):
+        # new_data is a flat list; map it onto the column names
+        columns = ['问题', '解析状态', '解析耗时', '执行状态', '执行耗时', '总耗时']
+        new_dict = dict(zip(columns, new_data))
+        # Append the row via loc
+        self.df.loc[len(self.df)] = new_dict
+
+    def print_analysis_result(self):
+        # Total number of test samples
+        total_samples = len(self.df)
+        # Number of questions parsed successfully
+        parse_success_count = (self.df['解析状态'] == '解析成功').sum()
+        # Number of questions executed successfully
+        execute_success_count = (self.df['执行状态'] == '执行成功').sum()
+        # Average parse time, rounded to two decimals
+        avg_parse_time = round(self.df['解析耗时'].mean(), 2)
+        # Average execution time, rounded to two decimals
+        avg_execute_time = round(self.df['执行耗时'].mean(), 2)
+        # Average total time, rounded to two decimals
+        avg_total_time = round(self.df['总耗时'].mean(), 2)
+        # Longest total time, rounded to two decimals
+        max_time = round(self.df['总耗时'].max(), 2)
+        # Shortest total time, rounded to two decimals
+        min_time = round(self.df['总耗时'].min(), 2)
+
+        print(f"测试样例总数 : {total_samples}")
+        print(f"解析成功数量 : {parse_success_count}")
+        print(f"执行成功数量 : {execute_success_count}")
+        print(f"解析平均耗时 : {avg_parse_time} 秒")
+        print(f"执行平均耗时 : {avg_execute_time} 秒")
+        print(f"总平均耗时 : {avg_total_time} 秒")
+        print(f"最长耗时 : {max_time} 秒")
+        print(f"最短耗时 : {min_time} 秒")
+
+    def write_to_csv(self):
+        # Create the res folder if it does not exist yet
+        if not os.path.exists('res'):
+            os.makedirs('res')
+        # Build a timestamped file name and write the detail rows under res/
+        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
+        file_path = os.path.join('res', f'{self.file_name}_{timestamp}.csv')
+        self.df.to_csv(file_path, index=False)
+        print(f"测试结果已保存到 {file_path}")
+
 class BatchTest:
     def __init__(self, url, agentId, chatId, userName):
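
As a quick illustration of how the new DataFrameAppender is meant to be driven, the sketch below feeds it two made-up result rows and then prints the summary and writes the CSV. It is only a usage sketch: the sample questions, timings, and the "demo" file name are invented here, and it assumes the class and the benchmark script's module-level pandas import (pd) are in scope.

# Minimal usage sketch (hypothetical data, not part of the commit)
appender = DataFrameAppender(file_name="demo")
# Row layout matches the header: question, parse status, parse time, execute status, execute time, total time (seconds)
appender.append_data(['sample question 1', '解析成功', 1.25, '执行成功', 0.40, 1.65])
appender.append_data(['sample question 2', '解析失败', 0.90, '执行失败', 0, 0.90])
appender.print_analysis_result()   # prints sample count, success counts, and avg/max/min timings
appender.write_to_csv()            # writes res/demo_<timestamp>.csv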

@@ -70,18 +132,35 @@ class BatchTest:
 def benchmark(url:str, agentId:str, chatId:str, filePath:str, userName:str):
     batch_test = BatchTest(url, agentId, chatId, userName)
     df = batch_test.read_question_from_csv(filePath)
+    appender = DataFrameAppender(os.path.basename(filePath))
     for index, row in df.iterrows():
         question = row['question']
         print('start to ask question:', question)
         # catch exceptions so that one failing question does not abort the run
         try:
             parse_resp = batch_test.parse(question)
-            batch_test.execute(agentId, question, parse_resp['data']['queryId'])
+            parse_status = '解析失败'
+            if parse_resp.get('data').get('errorMsg') is None:
+                parse_status = '解析成功'
+            parse_cost = parse_resp.get('data').get('parseTimeCost').get('parseTime')
+            execute_resp = batch_test.execute(agentId, question, parse_resp['data']['queryId'])
+            execute_status = '执行失败'
+            execute_cost = 0
+            if parse_status == '解析成功' and execute_resp.get('data').get('errorMsg') is None:
+                execute_status = '执行成功'
+                execute_cost = execute_resp.get('data').get('queryTimeCost')
+            res = [question.replace(',', '#'), parse_status, parse_cost / 1000, execute_status, execute_cost / 1000, (parse_cost + execute_cost) / 1000]
+            appender.append_data(res)
+
         except Exception as e:
             print('error:', e)
             traceback.print_exc()
             continue
         time.sleep(1)
+    # print the aggregated analysis result
+    appender.print_analysis_result()
+    # write the per-question detail rows
+    appender.write_to_csv()
+
 
 if __name__ == '__main__':
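
The __main__ body is not part of this diff, so for orientation only, a plausible way to invoke the updated benchmark could look like the following; the URL, IDs, user name, and CSV path are placeholders, not values taken from the repository.

if __name__ == '__main__':
    # Placeholder arguments; point them at a real SuperSonic deployment and question file
    benchmark(url='http://localhost:9080',
              agentId='1',
              chatId='1',
              filePath='questions.csv',
              userName='admin')

The question file is expected to contain a question column; each row then yields one detail line via DataFrameAppender, and the aggregated parse/execute statistics are printed once the loop finishes.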