252 lines
		
	
	
		
			8.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			252 lines
		
	
	
		
			8.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import argparse
 | ||
| import asyncio
 | ||
| import json
 | ||
| import os
 | ||
| import time
 | ||
| import websockets
 | ||
| import requests
 | ||
| from urllib.parse import quote
 | ||
| 
 | ||
class AudioClient:
    """WebSocket client that uploads an audio file in chunks and prints the replies.

    The protocol implemented by ``send_audio`` is: one JSON ``metadata``
    message, then one JSON ``chunk`` message per file slice (payload
    hex-encoded), each acknowledged by the server with ``{"type": "ack"}``,
    followed by result messages until one of type ``final_result`` or
    ``error`` arrives.

    Only one coroutine may read from the socket at a time, so
    ``receive_messages`` must not be running while ``send_audio`` is.
    """

    def __init__(self, server_uri: str, headers):
        self.server_uri = server_uri  # WebSocket address of the server
        self.headers = headers  # headers passed along with the handshake
        # Connection object returned by websockets.connect(), or None when
        # no connection has been opened (or it has been closed).
        self.websocket = None
        self.start_time: float = 0  # time.time() when the last upload started
        self.is_connected = False  # connection state flag
        # True while send_audio() is running and therefore owns recv().
        self._transfer_in_progress = False

    async def connect(self):
        """Open the WebSocket connection.

        Returns:
            True when the connection was established, False otherwise (the
            error is printed, not raised).
        """
        try:
            # NOTE(review): `extra_headers` is the keyword of the legacy
            # websockets client; the newer asyncio client names it
            # `additional_headers` - confirm against the installed version.
            self.websocket = await websockets.connect(
                self.server_uri, extra_headers=self.headers
            )
            self.is_connected = True
            print(f"已连接到服务端: {self.server_uri}")
            return True
        except Exception as e:
            print(f"连接失败: {str(e)}")
            self.is_connected = False
            return False

    async def send_audio(self, audio_path: str, chunk_size: int = 512 * 1024):
        """Upload ``audio_path`` in chunks and wait for the final result.

        Reconnects first if there is no open connection.

        Args:
            audio_path: Path of the local audio file to upload.
            chunk_size: Number of raw bytes per chunk message.

        Returns:
            True when a ``final_result`` message was received, False when the
            connection could not be established or anything failed during
            the exchange (the error is printed and ``is_connected`` is
            cleared).

        Raises:
            OSError: If the size of ``audio_path`` cannot be read.
        """
        # Set before the first await so that a listener task scheduled
        # earlier sees the flag the first time it gets to run.
        self._transfer_in_progress = True
        try:
            return await self._run_transfer(audio_path, chunk_size)
        finally:
            self._transfer_in_progress = False

    async def _run_transfer(self, audio_path, chunk_size):
        """Perform the metadata / chunks / results exchange for send_audio()."""
        if not self.websocket or not self.is_connected:
            print("连接未建立或已关闭,尝试重新连接...")
            if not await self.connect():
                return False

        print("开始发送请求..........")

        self.start_time = time.time()

        filename = os.path.basename(audio_path)
        file_size = os.path.getsize(audio_path)
        # Ceiling division: the last chunk may be shorter than chunk_size.
        # NOTE(review): an empty file yields total_chunks == 0, so only the
        # metadata message is sent - confirm the server handles that case.
        total_chunks = (file_size + chunk_size - 1) // chunk_size

        meta = {
            "type": "metadata",
            "filename": filename,
            "total_size": file_size,
            "total_chunks": total_chunks,
            # Extension of the file name only, so dots in directory names or
            # a missing extension never leak path fragments into the field.
            "format": os.path.splitext(filename)[1].lstrip("."),
        }

        print(f"发送文件源信息: {meta}")

        try:
            await self.websocket.send(json.dumps(meta))
            print("已发送元数据,等待确认...")

            ack_data = await self._recv_json()
            if ack_data.get("type") != "ack":
                raise RuntimeError(f"元数据发送失败: {ack_data}")

            print(f"开始发送文件分片: {filename}")
            await self._send_chunks(audio_path, chunk_size, total_chunks)
            print("\n所有分片发送完成,等待识别结果...")

            await self._wait_for_final_result()

            total_time = time.time() - self.start_time
            print(f"\n识别完成 (总耗时: {total_time:.2f}秒)")
            return True

        except Exception as e:
            print(f"发送过程出错: {str(e)}")
            self.is_connected = False
            return False

    async def _recv_json(self):
        """Receive one message from the socket and decode it as JSON."""
        return json.loads(await self.websocket.recv())

    async def _send_chunks(self, audio_path, chunk_size, total_chunks):
        """Send the file slice by slice, waiting for an ack after each one."""
        with open(audio_path, "rb") as f:
            for chunk_id in range(total_chunks):
                chunk_data = f.read(chunk_size)
                chunk_msg = {
                    "type": "chunk",
                    "chunk_id": chunk_id,
                    "is_last": chunk_id == total_chunks - 1,
                    # Binary payload travels hex-encoded inside the JSON text.
                    "data": chunk_data.hex(),
                }
                await self.websocket.send(json.dumps(chunk_msg))

                print(f"发送分片: {chunk_id}, 等待服务器接收")

                chunk_ack_data = await self._recv_json()
                if chunk_ack_data.get("type") != "ack":
                    raise RuntimeError(f"分片 {chunk_id} 发送失败")

                print(f"分片: {chunk_id} 接收成功")

                progress = (chunk_id + 1) / total_chunks * 100
                print(
                    f"发送进度: {progress:.1f}% ({chunk_id + 1}/{total_chunks})",
                    end="\r",
                )

    async def _wait_for_final_result(self):
        """Print result messages until ``final_result``; raise on ``error``."""
        while True:
            result_data = await self._recv_json()
            print(f"识别结果: {result_data}")

            message_type = result_data.get("type")
            if message_type == "final_result":
                return
            if message_type == "error":
                detail = result_data.get("message", result_data)
                raise RuntimeError(f"识别失败: {detail}")

    async def receive_messages(self):
        """Print every message from the server until the connection ends.

        Returns immediately when there is no open connection, or while
        ``send_audio`` is running: that coroutine reads the replies itself
        and the socket must not be read by two coroutines at once.
        """
        if (
            not self.websocket
            or not self.is_connected
            or self._transfer_in_progress
        ):
            return

        try:
            while self.is_connected:
                response = await self.websocket.recv()
                message = json.loads(response)
                print(f"服务端消息: {message}")

        except websockets.exceptions.ConnectionClosed:
            print("\n与服务端的连接已关闭")
            self.is_connected = False
        except Exception as e:
            print(f"\n接收消息出错: {str(e)}")
            self.is_connected = False

    async def close(self):
        """Close the WebSocket connection if one was ever opened.

        The socket is closed even when ``is_connected`` is already False,
        because a failed ``send_audio`` clears that flag while the socket
        object is still open.
        """
        connection = self.websocket
        self.is_connected = False
        if connection is None:
            return
        self.websocket = None
        await connection.close()
        print("已关闭连接")
 | ||
| 
 | ||
| 
 | ||
| 
 | ||
async def main(args):
    """Upload one audio file to the gateway's recognition WebSocket endpoint.

    Args:
        args: Parsed command-line namespace providing ``gateway_host``,
            ``gateway_port``, ``model``, ``audio_file`` and ``token``.

    Returns:
        None. Progress and errors are printed.
    """
    encoded_model = quote(args.model)
    # WebSocket address routed by the gateway.
    gateway_ws_url = (
        f"ws://{args.gateway_host}:{args.gateway_port}"
        "/learnware/models/openai/4pd/api/v1/voice/recognition/ws"
        f"?model={encoded_model}"
    )
    print(f"gateway_ws_url: {gateway_ws_url}")

    # Validate the input before opening any connection.
    file_path = args.audio_file
    if not os.path.isfile(file_path):
        print(f"错误: 音频文件不存在 - {file_path}")
        return

    # The bearer token travels in the handshake's Authorization header.
    header = {"Authorization": f"Bearer {args.token}"}

    client = AudioClient(server_uri=gateway_ws_url, headers=header)
    if not await client.connect():
        return  # connect() already reported the failure

    try:
        # send_audio() reads every server reply itself (acks and results),
        # so no separate listener task is started: a second reader on the
        # same socket would compete with it for recv().
        await client.send_audio(audio_path=file_path)
    except Exception as e:
        print(f"发生错误: {str(e)}")
    finally:
        await client.close()
 | ||
| 
 | ||
if __name__ == "__main__":
    # Command-line entry point: every option is mandatory and is handed to
    # main() unchanged as an argparse namespace.
    cli = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description="网关语音识别WebSocket测试脚本",
    )

    # (flag, extra keyword arguments) for each required option: gateway
    # location, model id, audio file and token.
    option_table = (
        ("--gateway-host", {"help": "网关IP或域名"}),
        ("--gateway-port", {"type": int, "help": "网关端口"}),
        ("--model", {"help": "语音识别模型ID"}),
        ("--audio-file", {"help": "本地音频文件路径"}),
        ("--token", {"help": "token"}),
    )
    mandatory = cli.add_argument_group("必传参数")
    for flag, extra in option_table:
        mandatory.add_argument(flag, required=True, **extra)

    # Parse the arguments and run the asynchronous client to completion.
    asyncio.run(main(cli.parse_args()))
 | ||
|     
 |