This commit is contained in:
4pdadmin 2025-09-19 17:49:06 +08:00
parent 67a07b1675
commit f747aaacd4

251
client.py Normal file
View File

@ -0,0 +1,251 @@
import argparse
import asyncio
import json
import os
import time
from typing import Optional
from urllib.parse import quote

import requests
import websockets
class AudioClient:
    """WebSocket client that uploads an audio file in chunks and prints the replies.

    Every message exchanged is a JSON text frame. The exchange implemented by
    ``send_audio`` is:

    1. send a ``metadata`` message and wait for a reply of type ``ack``;
    2. send each ``chunk`` message (payload hex-encoded) and wait for an
       ``ack`` after every chunk;
    3. read replies until one has type ``final_result`` (success) or
       ``error`` (failure).
    """

    def __init__(self, server_uri: str, headers):
        self.server_uri = server_uri  # WebSocket URI of the server
        self.headers = headers  # HTTP headers passed to websockets.connect()
        self.websocket: Optional[websockets.WebSocketClientProtocol] = None
        self.start_time: float = 0  # time.time() at the start of the last upload
        self.is_connected = False  # connection state flag

    async def connect(self):
        """Open the WebSocket connection.

        Returns:
            True when the connection was established, False otherwise. Errors
            are printed, not raised.
        """
        try:
            self.websocket = await self._open_connection()
            self.is_connected = True
            print(f"已连接到服务端: {self.server_uri}")
            return True
        except Exception as e:
            print(f"连接失败: {str(e)}")
            self.is_connected = False
            return False

    async def _open_connection(self):
        """Call websockets.connect() with the header keyword the installed release accepts."""
        try:
            # The asyncio client that websockets.connect resolves to in
            # recent releases names this argument ``additional_headers``.
            return await websockets.connect(
                self.server_uri, additional_headers=self.headers
            )
        except TypeError:
            # The legacy client only understands ``extra_headers``; the
            # unknown keyword is rejected before any network I/O happens.
            return await websockets.connect(
                self.server_uri, extra_headers=self.headers
            )

    @staticmethod
    def _audio_format(audio_path: str) -> str:
        """Return the file extension of *audio_path* without the dot ('' if none)."""
        # Work on the basename so dots in directory names are never mistaken
        # for an extension (e.g. "/data/v1.2/clip" has no extension).
        return os.path.splitext(os.path.basename(audio_path))[1].lstrip(".")

    @staticmethod
    def _count_chunks(file_size: int, chunk_size: int) -> int:
        """Return how many chunks of *chunk_size* bytes cover *file_size* bytes.

        Raises:
            ValueError: if *chunk_size* is not positive.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")
        return (file_size + chunk_size - 1) // chunk_size  # ceiling division

    async def send_audio(self, audio_path: str, chunk_size: int = 512 * 1024):
        """Upload one audio file and wait for the final recognition result.

        Args:
            audio_path: path of the local audio file to upload.
            chunk_size: number of raw bytes read from the file per chunk message.

        Returns:
            True once a ``final_result`` message has been received; False when
            the connection could not be (re)established or any step of the
            exchange failed (the error is printed).

        Raises:
            OSError: if the size of *audio_path* cannot be read.
            ValueError: if *chunk_size* is not positive.
        """
        if not self.websocket or not self.is_connected:
            print("连接未建立或已关闭,尝试重新连接...")
            if not await self.connect():
                return False
        print("开始发送请求..........")
        # Start of the measured interval reported at the end.
        self.start_time = time.time()
        filename = os.path.basename(audio_path)
        file_size = os.path.getsize(audio_path)
        # NOTE(review): an empty file yields zero chunks, so nothing is sent
        # after the metadata and the method then waits for a result; confirm
        # how the server reacts to that.
        total_chunks = self._count_chunks(file_size, chunk_size)
        meta = {
            "type": "metadata",
            "filename": filename,
            "total_size": file_size,
            "total_chunks": total_chunks,
            "format": self._audio_format(audio_path),
        }
        print(f"发送文件源信息: {meta}")
        try:
            await self.websocket.send(json.dumps(meta))
            print("已发送元数据,等待确认...")
            # The server must acknowledge the metadata before any chunk is sent.
            ack_data = json.loads(await self.websocket.recv())
            if ack_data["type"] != "ack":
                raise RuntimeError(f"元数据发送失败: {ack_data}")
            print(f"开始发送文件分片: {filename}")
            with open(audio_path, "rb") as f:
                for chunk_id in range(total_chunks):
                    chunk_data = f.read(chunk_size)
                    chunk_msg = {
                        "type": "chunk",
                        "chunk_id": chunk_id,
                        "is_last": chunk_id == total_chunks - 1,
                        # Binary payload travels hex-encoded inside the JSON frame.
                        "data": chunk_data.hex(),
                    }
                    await self.websocket.send(json.dumps(chunk_msg))
                    print(f"发送分片: {chunk_id}, 等待服务器接收")
                    # Each chunk is acknowledged individually before the next one.
                    chunk_ack_data = json.loads(await self.websocket.recv())
                    if chunk_ack_data["type"] != "ack":
                        raise RuntimeError(f"分片 {chunk_id} 发送失败")
                    print(f"分片: {chunk_id} 接收成功")
                    progress = (chunk_id + 1) / total_chunks * 100
                    print(
                        f"发送进度: {progress:.1f}% ({chunk_id + 1}/{total_chunks})",
                        end="\r",
                    )
            print("\n所有分片发送完成,等待识别结果...")
            # Keep reading until the server reports the final result or an error.
            while True:
                result_data = json.loads(await self.websocket.recv())
                print(f"识别结果: {result_data}")
                message_type = result_data["type"]
                if message_type == "error":
                    # .get(): a missing "message" key must not hide the real error.
                    raise RuntimeError(
                        f"识别失败: {result_data.get('message', result_data)}"
                    )
                if message_type == "final_result":
                    break
            total_time = time.time() - self.start_time
            print(f"\n识别完成 (总耗时: {total_time:.2f}秒)")
            return True
        except Exception as e:
            print(f"发送过程出错: {str(e)}")
            self.is_connected = False
            return False

    async def receive_messages(self):
        """Print every message received from the server until the connection ends.

        NOTE(review): the guard below returns immediately while the client IS
        connected, so the loop never runs in practice. The condition looks
        inverted, but it is kept as-is on purpose: send_audio() reads its
        acks and results from the same socket, and a second concurrent
        reader would compete for those messages. Decide whether this method
        is still needed before changing the guard.
        """
        if not self.websocket or self.is_connected:
            return
        try:
            while self.is_connected:
                response = await self.websocket.recv()
                message = json.loads(response)
                print(f"服务端消息: {message}")
        except websockets.exceptions.ConnectionClosed:
            print("\n与服务端的连接已关闭")
            self.is_connected = False
        except Exception as e:
            print(f"\n接收消息出错: {str(e)}")
            self.is_connected = False

    async def close(self):
        """Close the WebSocket connection if one was ever opened.

        The socket is closed even when ``is_connected`` is already False:
        send_audio() clears that flag on any error, and skipping the close in
        that state would leak the connection.
        """
        ws = self.websocket
        if ws is None:
            return
        self.is_connected = False
        self.websocket = None
        await ws.close()
        print("已关闭连接")
async def main(args):
    """Send one audio file to the gateway's recognition WebSocket endpoint.

    Args:
        args: parsed command-line namespace providing ``gateway_host``,
            ``gateway_port``, ``model``, ``audio_file`` and ``token``.

    Returns:
        None. Failures are reported on stdout rather than raised.
    """
    # Validate the local input first so a bad path never opens a connection.
    file_path = args.audio_file
    if not os.path.isfile(file_path):
        print(f"错误: 音频文件不存在 - {file_path}")
        return

    # The model id travels as a query parameter, so it is percent-encoded.
    encoded_model = quote(args.model)
    gateway_ws_url = (
        f"ws://{args.gateway_host}:{args.gateway_port}"
        f"/learnware/models/openai/4pd/api/v1/voice/recognition/ws?model={encoded_model}"
    )
    print(f"gateway_ws_url: {gateway_ws_url}")

    # NOTE: optional HTTP probes against the same host and port
    # (GET /health and GET /model) used to run here and are disabled.

    # Bearer token used by the gateway for authentication.
    header = {"Authorization": f"Bearer {args.token}"}

    client = AudioClient(server_uri=gateway_ws_url, headers=header)
    if not await client.connect():
        return  # connect() already printed the reason
    try:
        # send_audio() drives the whole exchange itself: it reads every ack
        # and every result message. No separate receive_messages() task is
        # started, because a second reader on the same socket would compete
        # for those messages.
        await client.send_audio(audio_path=file_path)
    except Exception as e:
        print(f"发生错误: {str(e)}")
    finally:
        await client.close()
if __name__ == "__main__":
    # Command-line entry point: every option is mandatory and is listed
    # under one dedicated help group.
    cli = argparse.ArgumentParser(
        description="网关语音识别WebSocket测试脚本",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    mandatory = cli.add_argument_group("必传参数")

    # (flag, extra keyword arguments) in the order they appear in --help.
    option_specs = (
        ("--gateway-host", {"help": "网关IP或域名"}),
        ("--gateway-port", {"type": int, "help": "网关端口"}),
        ("--model", {"help": "语音识别模型ID"}),
        ("--audio-file", {"help": "本地音频文件路径"}),
        ("--token", {"help": "token"}),
    )
    for flag, extra in option_specs:
        mandatory.add_argument(flag, required=True, **extra)

    # Parse the command line, then run the async client to completion.
    args = cli.parse_args()
    asyncio.run(main(args))