"""Command-line test client for the gateway speech-recognition WebSocket API.

The script connects to the gateway, uploads one local audio file in
hex-encoded chunks (each chunk is acknowledged by the server), then prints
the recognition messages until the final result arrives.
"""

import argparse
import asyncio
import json
import os
import time
from typing import Optional
from urllib.parse import quote

import requests
import websockets


class AudioClient:
    """WebSocket client that uploads an audio file in chunks for recognition."""

    def __init__(self, server_uri: str, headers):
        """Store the connection settings; no network I/O happens here.

        Args:
            server_uri: WebSocket URL of the recognition endpoint.
            headers: Extra HTTP headers sent with the WebSocket handshake.
        """
        self.server_uri = server_uri  # WebSocket address of the server
        self.headers = headers
        self.websocket: Optional[websockets.WebSocketClientProtocol] = None
        self.start_time: float = 0  # time.time() at the start of an upload
        self.is_connected = False  # connection state flag

    async def connect(self):
        """Open the WebSocket connection to the server.

        Returns:
            True if the connection was established, False otherwise. Failures
            are printed instead of being raised.
        """
        try:
            # NOTE(review): `extra_headers` is kept from the original call;
            # newer websockets releases may expect `additional_headers`
            # instead -- confirm against the installed version.
            self.websocket = await websockets.connect(
                self.server_uri, extra_headers=self.headers
            )
            self.is_connected = True
            print(f"已连接到服务端: {self.server_uri}")
            return True
        except Exception as e:
            print(f"连接失败: {str(e)}")
            self.is_connected = False
            return False

    async def _recv_json(self):
        """Receive one message from the server and decode it as JSON."""
        return json.loads(await self.websocket.recv())

    async def send_audio(self, audio_path: str, chunk_size: int = 512 * 1024):
        """Upload an audio file and wait for the final recognition result.

        The exchange is: send a metadata message and wait for its ack, send
        every chunk (hex-encoded) waiting for an ack after each one, then
        read result messages until one of type ``final_result`` arrives.

        Args:
            audio_path: Path of the local audio file to upload.
            chunk_size: Number of raw bytes carried by each chunk message.

        Returns:
            True when the final result was received, False on any failure
            (including a failed reconnect). Errors are printed, not raised.
        """
        if not self.websocket or not self.is_connected:
            print("连接未建立或已关闭,尝试重新连接...")
            if not await self.connect():
                return False
        print("开始发送请求..........")

        # Record the start time so the total duration can be reported.
        self.start_time = time.time()

        try:
            # File information. Kept inside the try so that a missing or
            # unreadable file is reported like any other upload failure.
            filename = os.path.basename(audio_path)
            file_size = os.path.getsize(audio_path)
            # Ceiling division: number of chunks needed to cover the file.
            total_chunks = (file_size + chunk_size - 1) // chunk_size
            meta = {
                "type": "metadata",
                "filename": filename,
                "total_size": file_size,
                "total_chunks": total_chunks,
                # Extension without the dot; empty when the file has none.
                # splitext on the basename avoids picking up dots that
                # belong to directory names.
                "format": os.path.splitext(filename)[1].lstrip("."),
            }
            print(f"发送文件源信息: {meta}")

            await self.websocket.send(json.dumps(meta))
            print("已发送元数据,等待确认...")

            # Wait for the metadata acknowledgement.
            ack_data = await self._recv_json()
            if ack_data.get("type") != "ack":
                raise Exception(f"元数据发送失败: {ack_data}")
            print(f"开始发送文件分片: {filename}")

            # Send the audio data chunk by chunk.
            with open(audio_path, "rb") as f:
                for chunk_id in range(total_chunks):
                    chunk_data = f.read(chunk_size)
                    chunk_msg = {
                        "type": "chunk",
                        "chunk_id": chunk_id,
                        "is_last": chunk_id == total_chunks - 1,
                        # Binary payload is hex-encoded to fit in JSON.
                        "data": chunk_data.hex(),
                    }
                    await self.websocket.send(json.dumps(chunk_msg))
                    print(f"发送分片: {chunk_id}, 等待服务器接收")

                    # Wait for the acknowledgement of this chunk.
                    chunk_ack_data = await self._recv_json()
                    if chunk_ack_data.get("type") != "ack":
                        raise Exception(f"分片 {chunk_id} 发送失败")
                    print(f"分片: {chunk_id} 接收成功")

                    # Show progress on a single, overwritten console line.
                    progress = (chunk_id + 1) / total_chunks * 100
                    print(
                        f"发送进度: {progress:.1f}% ({chunk_id + 1}/{total_chunks})",
                        end="\r",
                    )
            print("\n所有分片发送完成,等待识别结果...")

            # Keep listening until the final message (or an error) arrives.
            while True:
                result_data = await self._recv_json()
                print(f"识别结果: {result_data}")
                result_type = result_data.get("type")
                if result_type == "error":
                    # The server reported a failure.
                    raise Exception(
                        f"识别失败: {result_data.get('message', result_data)}"
                    )
                if result_type == "final_result":
                    break

            total_time = time.time() - self.start_time
            print(f"\n识别完成 (总耗时: {total_time:.2f}秒)")
            return True
        except Exception as e:
            print(f"发送过程出错: {str(e)}")
            self.is_connected = False
            return False

    async def receive_messages(self):
        """Print every message sent by the server until the connection ends.

        Returns immediately when there is no open connection. Must not run
        concurrently with ``send_audio``: both read from the same socket, and
        only one coroutine may wait in ``recv()`` at a time.
        """
        if not self.websocket or not self.is_connected:
            return
        try:
            while self.is_connected:
                response = await self.websocket.recv()
                message = json.loads(response)
                print(f"服务端消息: {message}")
        except websockets.exceptions.ConnectionClosed:
            print("\n与服务端的连接已关闭")
            self.is_connected = False
        except Exception as e:
            print(f"\n接收消息出错: {str(e)}")
            self.is_connected = False

    async def close(self):
        """Close the WebSocket connection if one was ever opened.

        The socket is closed even when ``is_connected`` is already False,
        because the error paths clear that flag without closing the socket.
        """
        if self.websocket is None:
            return
        self.is_connected = False
        try:
            await self.websocket.close()
        finally:
            # Drop the reference so a later send_audio() reconnects cleanly.
            self.websocket = None
        print("已关闭连接")


async def main(args):
    """Upload ``args.audio_file`` through the gateway and print the results.

    Args:
        args: Parsed command-line arguments providing ``gateway_host``,
            ``gateway_port``, ``model``, ``audio_file`` and ``token``.
    """
    encoded_model = quote(args.model)
    # WebSocket address (routed by the gateway).
    gateway_ws_url = (
        f"ws://{args.gateway_host}:{args.gateway_port}"
        "/learnware/models/openai/4pd/api/v1/voice/recognition/ws"
        f"?model={encoded_model}"
    )
    print(f"gateway_ws_url: {gateway_ws_url}")

    # Optional HTTP pre-flight checks, currently disabled. They would query
    # the gateway's /health and /model endpoints and abort on failure:
    #
    # gateway_http_url = f"http://{args.gateway_host}:{args.gateway_port}"
    # for endpoint in ("health", "model"):
    #     try:
    #         response = requests.get(f"{gateway_http_url}/{endpoint}", timeout=10)
    #         print(endpoint, response.status_code, response.json())
    #     except Exception as e:
    #         print(f"{endpoint} check failed: {e}")
    #         return

    # Validate the input file before opening a connection.
    file_path = args.audio_file
    if not os.path.isfile(file_path):
        print(f"错误: 音频文件不存在 - {file_path}")
        return

    # Request headers carrying the Bearer token used for authentication.
    header = {"Authorization": f"Bearer {args.token}"}

    # Create the client and connect; give up if the connection fails.
    client = AudioClient(server_uri=gateway_ws_url, headers=header)
    if not await client.connect():
        return

    try:
        # send_audio() reads every server response itself (acks and results),
        # so no background receive task is started: a second reader on the
        # same socket would conflict with it.
        await client.send_audio(audio_path=file_path)
    except Exception as e:
        print(f"发生错误: {str(e)}")
    finally:
        await client.close()


def _build_parser():
    """Build the command-line parser; every option is required."""
    parser = argparse.ArgumentParser(
        description="网关语音识别WebSocket测试脚本",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    # Required arguments: gateway connection, model, audio file and token.
    required_args = parser.add_argument_group("必传参数")
    required_args.add_argument(
        "--gateway-host", required=True, help="网关IP或域名"
    )
    required_args.add_argument(
        "--gateway-port", required=True, type=int, help="网关端口"
    )
    required_args.add_argument(
        "--model", required=True, help="语音识别模型ID"
    )
    required_args.add_argument(
        "--audio-file", required=True, help="本地音频文件路径"
    )
    required_args.add_argument(
        "--token", required=True, help="token"
    )
    return parser


if __name__ == "__main__":
    # Parse the arguments and run the asynchronous client.
    args = _build_parser().parse_args()
    asyncio.run(main(args))