# dataset-test003/client.py
# Last modified: 2025-09-19 17:49:06 +08:00 (252 lines, 8.7 KiB, Python)
#
# NOTE: the hosting web view warned that this file contains ambiguous Unicode
# characters, i.e. characters that might be confused with other characters.
# If that is intentional it can be ignored; otherwise review them.

import argparse
import asyncio
import json
import os
import time
from typing import Optional
from urllib.parse import quote

import requests
import websockets
class AudioClient:
    """Upload a local audio file over a WebSocket and print the recognition results.

    Message flow implemented by ``send_audio``:

    1. send a JSON ``{"type": "metadata", ...}`` message and wait for a reply whose
       ``type`` is ``"ack"``;
    2. send the file as JSON ``{"type": "chunk", ...}`` messages with hex-encoded
       ``data``, waiting for an ``"ack"`` reply after each one;
    3. read and print messages until one has ``type == "final_result"``
       (a ``type == "error"`` message aborts the upload).
    """

    def __init__(self, server_uri: str, headers):
        """Remember where and how to connect; no network I/O happens here.

        Args:
            server_uri: WebSocket URL to connect to.
            headers: HTTP headers sent with the WebSocket handshake (``main`` passes
                an ``Authorization: Bearer ...`` header).
        """
        self.server_uri = server_uri  # server WebSocket address
        self.headers = headers
        # NOTE(review): WebSocketClientProtocol is the legacy websockets type; this
        # annotation sits in function scope, so it is never evaluated at runtime.
        self.websocket: Optional[websockets.WebSocketClientProtocol] = None
        self.start_time: float = 0  # time.time() at the start of the current upload
        self.is_connected = False  # connection-state flag

    async def connect(self):
        """Open the WebSocket connection to ``self.server_uri``.

        Returns:
            True on success; False (after printing the error) on any failure.
        """
        try:
            try:
                self.websocket = await websockets.connect(
                    self.server_uri, extra_headers=self.headers
                )
            except TypeError:
                # Newer websockets releases renamed ``extra_headers`` to
                # ``additional_headers`` and reject the old keyword with a
                # TypeError, so retry with the new name.
                self.websocket = await websockets.connect(
                    self.server_uri, additional_headers=self.headers
                )
            self.is_connected = True
            print(f"已连接到服务端: {self.server_uri}")
            return True
        except Exception as e:
            print(f"连接失败: {str(e)}")
            self.is_connected = False
            return False

    @staticmethod
    def _build_metadata(audio_path: str, chunk_size: int) -> dict:
        """Return the ``metadata`` message describing *audio_path* cut into *chunk_size* pieces.

        Raises:
            OSError: if the file size cannot be read (e.g. the file does not exist).
        """
        filename = os.path.basename(audio_path)
        file_size = os.path.getsize(audio_path)
        return {
            "type": "metadata",
            "filename": filename,
            "total_size": file_size,
            "total_chunks": (file_size + chunk_size - 1) // chunk_size,  # ceiling division
            # Use the extension of the file name only, so a dot in a directory name
            # (e.g. "dir.v2/clip") is not mistaken for an extension.
            "format": os.path.splitext(filename)[1][1:],
        }

    async def send_audio(self, audio_path: str, chunk_size: int = 512 * 1024):
        """Upload *audio_path* in chunks and print server messages until the final result.

        Connects first if there is no open connection. Each chunk of up to
        *chunk_size* bytes is sent hex-encoded, so each JSON frame is a little more
        than twice *chunk_size*. Every chunk must be acknowledged before the next
        one is sent.

        NOTE(review): websockets' default ``max_size`` for incoming messages is
        1 MiB; if the server uses that default, the default 512 KiB chunk
        (~1 MiB of hex plus JSON overhead) slightly exceeds it — confirm the
        server configuration.

        Returns:
            True once a ``final_result`` message has been received; False if the
            connection could not be established or any step failed (the error is
            printed and ``is_connected`` is cleared).

        Raises:
            OSError: if *audio_path* cannot be stat'ed (raised before anything is sent).
        """
        if not self.websocket or not self.is_connected:
            print("连接未建立或已关闭,尝试重新连接...")
            if not await self.connect():
                return False
        print("开始发送请求..........")
        # Start of the timed upload + recognition round trip.
        self.start_time = time.time()
        meta = self._build_metadata(audio_path, chunk_size)
        filename = meta["filename"]
        total_chunks = meta["total_chunks"]
        print(f"发送文件源信息: {meta}")
        try:
            await self.websocket.send(json.dumps(meta))
            print("已发送元数据,等待确认...")
            ack_data = json.loads(await self.websocket.recv())
            if ack_data["type"] != "ack":
                raise Exception(f"元数据发送失败: {ack_data}")
            print(f"开始发送文件分片: {filename}")
            # NOTE(review): an empty file gives total_chunks == 0, so no chunk (and no
            # is_last=True) is sent; whether the server still answers with a
            # final_result in that case is not visible here — confirm.
            with open(audio_path, "rb") as f:
                for chunk_id in range(total_chunks):
                    chunk_data = f.read(chunk_size)
                    is_last = chunk_id == total_chunks - 1
                    # Binary data is hex-encoded so it can travel in a JSON text frame.
                    chunk_msg = {
                        "type": "chunk",
                        "chunk_id": chunk_id,
                        "is_last": is_last,
                        "data": chunk_data.hex(),
                    }
                    await self.websocket.send(json.dumps(chunk_msg))
                    print(f"发送分片: {chunk_id}, 等待服务器接收")
                    # Stop-and-wait: each chunk must be acknowledged before the next.
                    chunk_ack_data = json.loads(await self.websocket.recv())
                    if chunk_ack_data["type"] != "ack":
                        raise Exception(f"分片 {chunk_id} 发送失败")
                    print(f"分片: {chunk_id} 接收成功")
                    progress = (chunk_id + 1) / total_chunks * 100
                    print(f"发送进度: {progress:.1f}% ({chunk_id + 1}/{total_chunks})", end="\r")
            print("\n所有分片发送完成,等待识别结果...")
            # Keep reading until the server sends the final message.
            while True:
                result_data = json.loads(await self.websocket.recv())
                print(f"识别结果: {result_data}")
                if result_data["type"] == "error":
                    raise Exception(f"识别失败: {result_data['message']}")
                if result_data["type"] == "final_result":
                    break
            total_time = time.time() - self.start_time
            print(f"\n识别完成 (总耗时: {total_time:.2f}秒)")
            return True
        except Exception as e:
            print(f"发送过程出错: {str(e)}")
            self.is_connected = False
            return False

    async def receive_messages(self):
        """Print every message the server sends until the connection closes.

        NOTE(review): the guard below returns immediately whenever the client IS
        connected (``or self.is_connected`` — a ``not`` looks missing). When it is
        not connected, the ``while`` loop never runs. As written, this method
        therefore never reads anything.

        Flipping the guard is not a drop-in fix. ``main`` runs this as a task
        concurrently with ``send_audio``, and websockets does not allow two
        coroutines to wait in ``recv()`` on the same connection.
        """
        if not self.websocket or self.is_connected:
            return
        try:
            while self.is_connected:
                response = await self.websocket.recv()
                message = json.loads(response)
                print(f"服务端消息: {message}")
        except websockets.exceptions.ConnectionClosed:
            print("\n与服务端的连接已关闭")
            self.is_connected = False
        except Exception as e:
            print(f"\n接收消息出错: {str(e)}")
            self.is_connected = False

    async def close(self):
        """Close the WebSocket connection if one was opened.

        The socket is closed whenever one exists, even if ``is_connected`` was
        already cleared by a failed ``send_audio``. Otherwise that socket would
        be left open.
        """
        if self.websocket is None:
            return
        self.is_connected = False
        websocket, self.websocket = self.websocket, None
        await websocket.close()
        print("已关闭连接")
async def main(args):
    """Upload ``args.audio_file`` to the gateway's speech-recognition WebSocket.

    Reads ``args.gateway_host``, ``args.gateway_port``, ``args.model``,
    ``args.audio_file`` and ``args.token`` (the namespace built by the
    ``__main__`` argparse block). Errors are printed, not raised.
    """
    # Validate the input before any network work: there is no point connecting
    # and authenticating against the gateway for a file that does not exist.
    file_path = args.audio_file
    if not os.path.exists(file_path):
        print(f"错误: 音频文件不存在 - {file_path}")
        return
    # The model id is placed in the query string, so it must be URL-encoded.
    encoded_model = quote(args.model)
    # Base HTTP address; only used by the disabled HTTP checks below.
    gateway_http_url = f"http://{args.gateway_host}:{args.gateway_port}"
    # WebSocket address (gateway route).
    gateway_ws_url = f"ws://{args.gateway_host}:{args.gateway_port}/learnware/models/openai/4pd/api/v1/voice/recognition/ws?model={encoded_model}"
    print(f"gateway_ws_url: {gateway_ws_url}")
    # Optional HTTP checks (service health, model info), currently disabled:
    # # #http请求: 服务健康检查
    # health_check_url = f"{gateway_http_url}/health"
    # try:
    #     response = requests.get(health_check_url, timeout=10)
    #     print("健康检查状态码:", response.status_code)
    #     print("健康检查响应内容:", response.json())
    # except Exception as e:
    #     print(f"健康检查失败: {str(e)}")
    #     return  # 健康检查失败则退出
    # # #http请求: 获取模型信息
    # model_url = f"{gateway_http_url}/model"
    # try:
    #     response = requests.get(model_url, timeout=10)
    #     print("模型信息状态码:", response.status_code)
    #     print("模型信息响应内容:", response.json())
    # except Exception as e:
    #     print(f"获取模型信息失败: {str(e)}")
    #     return  # 获取模型信息失败则退出
    # Request headers carrying the Bearer token used for authentication.
    header = {
        "Authorization": f"Bearer {args.token}"
    }
    client = AudioClient(server_uri=gateway_ws_url, headers=header)
    if not await client.connect():
        return  # connection failed: nothing to clean up
    receive_task = None
    try:
        # Background task for server messages (non-blocking).
        receive_task = asyncio.create_task(client.receive_messages())
        await client.send_audio(audio_path=file_path)
        await receive_task
    except Exception as e:
        print(f"发生错误: {str(e)}")
    finally:
        if receive_task is not None and not receive_task.done():
            # Don't leave the background reader pending if send_audio raised.
            receive_task.cancel()
        await client.close()
if __name__ == "__main__":
    # Command-line entry point. Every option is mandatory; all are shown in one
    # dedicated help group, and the parsed namespace is handed to main().
    cli = argparse.ArgumentParser(
        description="网关语音识别WebSocket测试脚本",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    mandatory = cli.add_argument_group("必传参数")
    mandatory.add_argument("--gateway-host", required=True, help="网关IP或域名")
    mandatory.add_argument("--gateway-port", required=True, type=int, help="网关端口")
    mandatory.add_argument("--model", required=True, help="语音识别模型ID")
    mandatory.add_argument("--audio-file", required=True, help="本地音频文件路径")
    mandatory.add_argument("--token", required=True, help="token")
    asyncio.run(main(cli.parse_args()))