diff --git a/client.py b/client.py new file mode 100644 index 0000000..d89b6fe --- /dev/null +++ b/client.py @@ -0,0 +1,251 @@ +import argparse +import asyncio +import json +import os +import time +import websockets +import requests +from urllib.parse import quote + +class AudioClient: + def __init__(self, server_uri: str, headers): + self.server_uri = server_uri # 服务端WebSocket地址 + self.headers = headers + self.websocket: Optional[websockets.WebSocketClientProtocol] = None + self.start_time: float = 0 # 记录开始时间 + self.is_connected = False # 连接状态标记 + + async def connect(self): + """连接到WebSocket服务端""" + try: + self.websocket = await websockets.connect(self.server_uri, extra_headers=self.headers) + self.is_connected = True + print(f"已连接到服务端: {self.server_uri}") + return True + except Exception as e: + print(f"连接失败: {str(e)}") + self.is_connected = False + return False + + async def send_audio(self, audio_path: str, chunk_size: int = 512 * 1024): + """发送音频文件元数据(名称、大小、分片数等)""" + if not self.websocket or not self.is_connected: + print("连接未建立或已关闭,尝试重新连接...") + if not await self.connect(): + return False + + print("开始发送请求..........") + + # 记录开始时间 + self.start_time = time.time() + + # 音频文件信息 + filename = os.path.basename(audio_path) + file_size = os.path.getsize(audio_path) + total_chunks = (file_size + chunk_size - 1) // chunk_size # 计算总分片数 + + meta = { + "type": "metadata", + "filename": filename, + "total_size": file_size, + "total_chunks": total_chunks, + "format": audio_path.split(".")[-1], + } + + print(f"发送文件源信息: {meta}") + + try: + await self.websocket.send(json.dumps(meta)) + print("已发送元数据,等待确认...") + + # 等待元数据确认 + ack = await self.websocket.recv() + ack_data = json.loads(ack) + if ack_data["type"] != "ack": + raise Exception(f"元数据发送失败: {ack_data}") + + print(f"开始发送文件分片: {filename}") + + # 分片发送音频数据 + with open(audio_path, "rb") as f: + for chunk_id in range(total_chunks): + # 读取分片数据 + chunk_data = f.read(chunk_size) + is_last = (chunk_id == total_chunks - 1) + + # 发送分片(Hex编码) + chunk_msg = { + "type": "chunk", + "chunk_id": chunk_id, + "is_last": is_last, + "data": chunk_data.hex() + } + await self.websocket.send(json.dumps(chunk_msg)) + + print(f"发送分片: {chunk_id}, 等待服务器接收") + + # 等待分片确认 + chunk_ack = await self.websocket.recv() + chunk_ack_data = json.loads(chunk_ack) + + if chunk_ack_data["type"] != "ack": + raise Exception(f"分片 {chunk_id} 发送失败") + + print(f"分片: {chunk_id} 接收成功") + + # 显示进度 + progress = (chunk_id + 1) / total_chunks * 100 + print(f"发送进度: {progress:.1f}% ({chunk_id + 1}/{total_chunks})", end="\r") + + print("\n所有分片发送完成,等待识别结果...") + + # 循环监听,直到得到最终消息 + is_listen = True + while is_listen: + # 3. 获取识别结果 + result = await self.websocket.recv() + result_data = json.loads(result) + print(f"识别结果: {result_data}") + + # 获取最终消息 + if result_data["type"] == "final_result": + is_listen = False + + # 服务端返回异常 + if result_data["type"] == "error": + raise Exception(f"识别失败: {result_data['message']}") + + total_time = time.time() - self.start_time + print(f"\n识别完成 (总耗时: {total_time:.2f}秒)") + + except Exception as e: + print(f"发送过程出错: {str(e)}") + self.is_connected = False + return False + + async def receive_messages(self): + """持续接收服务端的响应消息""" + if not self.websocket or self.is_connected: + return + + try: + while self.is_connected: + response = await self.websocket.recv() + message = json.loads(response) + print(f"服务端消息: {message}") + + except websockets.exceptions.ConnectionClosed: + print("\n与服务端的连接已关闭") + self.is_connected = False + except Exception as e: + print(f"\n接收消息出错: {str(e)}") + self.is_connected = False + + async def close(self): + """关闭WebSocket连接""" + if self.websocket and self.is_connected: + self.is_connected = False + await self.websocket.close() + print("已关闭连接") + + + +async def main(args): + encoded_model = quote(args.model) + gateway_http_url = f"http://{args.gateway_host}:{args.gateway_port}" # HTTP基础地址(用于健康检查) + gateway_ws_url = f"ws://{args.gateway_host}:{args.gateway_port}/learnware/models/openai/4pd/api/v1/voice/recognition/ws?model={encoded_model}" # WebSocket地址(网关路由) + print(f"gateway_ws_url: {gateway_ws_url}") + + # #http请求: 服务健康检查 + # health_check_url = f"{gateway_http_url}/health" + # try: + # response = requests.get(health_check_url, timeout=10) + # print("健康检查状态码:", response.status_code) + # print("健康检查响应内容:", response.json()) + # except Exception as e: + # print(f"健康检查失败: {str(e)}") + # return # 健康检查失败则退出 + + # #http请求: 获取模型信息 + # model_url = f"{gateway_http_url}/model" + # try: + # response = requests.get(model_url, timeout=10) + # print("模型信息状态码:", response.status_code) + # print("模型信息响应内容:", response.json()) + # except Exception as e: + # print(f"获取模型信息失败: {str(e)}") + # return # 获取模型信息失败则退出 + + # 构造请求头,携带 Bearer 令牌 + header = { + "Authorization": f"Bearer {args.token}" # 核心:添加认证头 + } + # 初始化客户端并连接 + client = AudioClient(server_uri=gateway_ws_url, headers=header) + if not await client.connect(): + return # 连接失败则退出 + + try: + # 启动一个任务持续接收消息(非阻塞) + receive_task = asyncio.create_task(client.receive_messages()) + + #发送音频文件 + file_path=args.audio_file + # 检查文件是否存在 + if not os.path.exists(file_path): + print(f"错误: 音频文件不存在 - {file_path}") + return + await client.send_audio(audio_path=file_path) + + # 等待所有任务完成 + await receive_task + + except Exception as e: + print(f"发生错误: {str(e)}") + + finally: + # 关闭连接 + await client.close() + +if __name__ == "__main__": + # 配置命令行参数解析 + parser = argparse.ArgumentParser( + description="网关语音识别WebSocket测试脚本", + formatter_class=argparse.RawTextHelpFormatter + ) + + # 必传参数(网关连接+模型+音频文件) + required_args = parser.add_argument_group("必传参数") + required_args.add_argument( + "--gateway-host", + required=True, + help="网关IP或域名" + ) + required_args.add_argument( + "--gateway-port", + required=True, + type=int, + help="网关端口" + ) + required_args.add_argument( + "--model", + required=True, + help="语音识别模型ID" + ) + required_args.add_argument( + "--audio-file", + required=True, + help="本地音频文件路径" + ) + required_args.add_argument( + "--token", + required=True, + help="token" + ) + + # 解析参数并启动异步任务 + args = parser.parse_args() + + # 运行客户端 + asyncio.run(main(args)) +