Files
supersonic/benchmark/benchmark.py
Antgeek 97710a90c4
Some checks failed
supersonic CentOS CI / build (21) (push) Has been cancelled
supersonic mac CI / build (21) (push) Has been cancelled
supersonic ubuntu CI / build (21) (push) Has been cancelled
supersonic windows CI / build (21) (push) Has been cancelled
(improve)(benchmark) improve benchmark, add analysis of parsing results (#2215)
2025-04-10 08:57:25 +08:00

184 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# -----------------------------------------------------------------------------------
'''
@filename : batchmark.py
@time : 2024/06/20
@author : zhaodongsheng
@Version : 1.0
@description : 批量问答测试
'''
# -----------------------------------------------------------------------------------
import pandas as pd
import json
import requests
import time
import jwt
import traceback
import os
from datetime import datetime
class DataFrameAppender:
def __init__(self,file_name = "output"):
# 定义表头
columns = ['问题', '解析状态', '解析耗时', '执行状态', '执行耗时', '总耗时']
# 创建只有表头的 DataFrame
self.df = pd.DataFrame(columns=columns)
self.file_name = file_name
def append_data(self, new_data):
# 假设 new_data 是一维数组,将其转换为字典
columns = ['问题', '解析状态', '解析耗时', '执行状态', '执行耗时', '总耗时']
new_dict = dict(zip(columns, new_data))
# 使用 loc 方法追加数据
self.df.loc[len(self.df)] = new_dict
def print_analysis_result(self):
# 测试样例总数
total_samples = len(self.df)
# 解析成功数量
parse_success_count = (self.df['解析状态'] == '解析成功').sum()
# 执行成功数量
execute_success_count = (self.df['执行状态'] == '执行成功').sum()
# 解析平均耗时,保留两位小数
avg_parse_time = round(self.df['解析耗时'].mean(), 2)
# 执行平均耗时,保留两位小数
avg_execute_time = round(self.df['执行耗时'].mean(), 2)
# 总平均耗时,保留两位小数
avg_total_time = round(self.df['总耗时'].mean(), 2)
# 最长耗时,保留两位小数
max_time = round(self.df['总耗时'].max(), 2)
# 最短耗时,保留两位小数
min_time = round(self.df['总耗时'].min(), 2)
print(f"测试样例总数 : {total_samples}")
print(f"解析成功数量 : {parse_success_count}")
print(f"执行成功数量 : {execute_success_count}")
print(f"解析平均耗时 : {avg_parse_time}")
print(f"执行平均耗时 : {avg_execute_time}")
print(f"总平均耗时 : {avg_total_time}")
print(f"最长耗时 : {max_time}")
print(f"最短耗时 : {min_time}")
def write_to_csv(self):
# 检查 data 文件夹是否存在,如果不存在则创建
if not os.path.exists('res'):
os.makedirs('res')
# 获取当前时间戳
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
# 生成带时间戳的文件名
file_path = os.path.join('res', f'{self.file_name}_{timestamp}.csv')
self.df.to_csv(file_path, index=False)
print(f"测试结果已保存到 {file_path}")
class BatchTest:
def __init__(self, url, agentId, chatId, userName):
self.base_url = url + '/api/chat/query/'
self.agentId = agentId
self.auth_token = self.__get_authorization(userName)
self.chatId = chatId
def parse(self, query_text):
url = self.base_url + 'parse'
data = {
'queryText': query_text,
'agentId': self.agentId,
'chatId': self.chatId,
}
headers = {
'Authorization': 'Bearer ' + self.auth_token,
'Content-Type': 'application/json',
}
response = requests.post(url, headers=headers, data=json.dumps(data))
return response.json()
def execute(self, agentId, query_text, queryId):
url = self.base_url + 'execute'
data = {
'agentId': agentId,
'queryText': query_text,
'parseId': 1,
'chatId': self.chatId,
'queryId': queryId,
}
headers = {
'Authorization': 'Bearer ' + self.auth_token,
'Content-Type': 'application/json',
}
response = requests.post(url, headers=headers, data=json.dumps(data))
return response.json()
def read_question_from_csv(self, filePath):
df = pd.read_csv(filePath)
return df
def __get_authorization(self, userName):
# secret 请和 com.tencent.supersonic.auth.api.authentication.config.AuthenticationConfig.tokenAppSecret 保持一致
secret = "WIaO9YRRVt+7QtpPvyWsARFngnEcbaKBk783uGFwMrbJBaochsqCH62L4Kijcb0sZCYoSsiKGV/zPml5MnZ3uQ=="
exp = time.time() + 100000000
token= jwt.encode({"token_user_name": userName,"exp": exp}, secret, algorithm="HS512")
return token
def benchmark(url:str, agentId:str, chatId:str, filePath:str, userName:str):
batch_test = BatchTest(url, agentId, chatId, userName)
df = batch_test.read_question_from_csv(filePath)
appender = DataFrameAppender(os.path.basename(filePath))
for index, row in df.iterrows():
question = row['question']
print('start to ask question:', question)
# 捕获异常,防止程序中断
try:
parse_resp = batch_test.parse(question)
parse_status = '解析失败'
if parse_resp.get('data').get('errorMsg') is None:
parse_status = '解析成功'
parse_cost = parse_resp.get('data').get('parseTimeCost').get('parseTime')
execute_resp = batch_test.execute(agentId, question, parse_resp['data']['queryId'])
execute_status = '执行失败'
execute_cost = 0
if parse_status == '解析成功' and execute_resp.get('data').get('errorMsg') is None:
execute_status = '执行成功'
execute_cost = execute_resp.get('data').get('queryTimeCost')
res = [question.replace(',', '#'),parse_status,parse_cost/1000,execute_status,execute_cost/1000,(parse_cost+execute_cost)/1000]
appender.append_data(res)
except Exception as e:
print('error:', e)
traceback.print_exc()
continue
time.sleep(1)
# 打印分析结果
appender.print_analysis_result()
# 分析明细输出
appender.write_to_csv()
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('-u', '--url', type=str, required=True, help='url:问答系统url,例如https://chatdata-dev.test.com')
parser.add_argument('-a', '--agentId', type=str, required=True, help='agentId助手ID')
parser.add_argument('-c', '--chatId', type=str, required=True, help='chatId:会话ID,需要通过浏览器开发者模式获取')
parser.add_argument('-f', '--filePath', type=str, required=True, help='filePath问题文件路径, csv格式. 请提前上传到benchmark/data目录下')
parser.add_argument('-p', '--userName', type=str, required=True, help='userName用户名用户获取登录token')
args = parser.parse_args()
print('批量测试配置信息[url:', args.url,'agentId:', args.agentId, 'chatId:', args.chatId, 'filePath:', args.filePath, 'userName:', args.userName, ']')
print('请确认输入的压力测试信息是否正确:')
print('1. Yes')
print('2. No')
confirm = input()
if confirm == '1' or confirm == 'Yes' or confirm == 'yes' or confirm == 'YES':
benchmark(args.url, args.agentId, args.chatId, args.filePath, args.userName)
else:
print('请重新输入压力测试配置信息: url, agentId, chatId, filePath, userName')