add file
This commit is contained in:
		
							parent
							
								
									67a07b1675
								
							
						
					
					
						commit
						f747aaacd4
					
				
							
								
								
									
										251
									
								
								client.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										251
									
								
								client.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,251 @@ | ||||
| import argparse | ||||
| import asyncio | ||||
| import json | ||||
| import os | ||||
| import time | ||||
| import websockets | ||||
| import requests | ||||
| from urllib.parse import quote | ||||
| 
 | ||||
| class AudioClient: | ||||
|     def __init__(self, server_uri: str, headers): | ||||
|         self.server_uri = server_uri  # 服务端WebSocket地址 | ||||
|         self.headers = headers | ||||
|         self.websocket: Optional[websockets.WebSocketClientProtocol] = None | ||||
|         self.start_time: float = 0  # 记录开始时间 | ||||
|         self.is_connected = False  # 连接状态标记 | ||||
| 
 | ||||
|     async def connect(self): | ||||
|         """连接到WebSocket服务端""" | ||||
|         try: | ||||
|             self.websocket = await websockets.connect(self.server_uri, extra_headers=self.headers) | ||||
|             self.is_connected = True | ||||
|             print(f"已连接到服务端: {self.server_uri}") | ||||
|             return True | ||||
|         except Exception as e: | ||||
|             print(f"连接失败: {str(e)}") | ||||
|             self.is_connected = False | ||||
|             return False | ||||
|          | ||||
|     async def send_audio(self, audio_path: str, chunk_size: int = 512 * 1024): | ||||
|         """发送音频文件元数据(名称、大小、分片数等)""" | ||||
|         if not self.websocket or not self.is_connected: | ||||
|             print("连接未建立或已关闭,尝试重新连接...") | ||||
|             if not await self.connect(): | ||||
|                 return False | ||||
| 
 | ||||
|         print("开始发送请求..........") | ||||
| 
 | ||||
|         # 记录开始时间 | ||||
|         self.start_time = time.time() | ||||
| 
 | ||||
|         # 音频文件信息 | ||||
|         filename = os.path.basename(audio_path) | ||||
|         file_size = os.path.getsize(audio_path) | ||||
|         total_chunks = (file_size + chunk_size - 1) // chunk_size  # 计算总分片数 | ||||
| 
 | ||||
|         meta = { | ||||
|             "type": "metadata", | ||||
|             "filename": filename, | ||||
|             "total_size": file_size, | ||||
|             "total_chunks": total_chunks, | ||||
|             "format": audio_path.split(".")[-1], | ||||
|         } | ||||
| 
 | ||||
|         print(f"发送文件源信息: {meta}") | ||||
| 
 | ||||
|         try: | ||||
|             await self.websocket.send(json.dumps(meta)) | ||||
|             print("已发送元数据,等待确认...") | ||||
|                  | ||||
|             # 等待元数据确认 | ||||
|             ack = await self.websocket.recv() | ||||
|             ack_data = json.loads(ack) | ||||
|             if ack_data["type"] != "ack": | ||||
|                 raise Exception(f"元数据发送失败: {ack_data}") | ||||
|              | ||||
|             print(f"开始发送文件分片: {filename}") | ||||
| 
 | ||||
|             # 分片发送音频数据 | ||||
|             with open(audio_path, "rb") as f: | ||||
|                 for chunk_id in range(total_chunks): | ||||
|                     # 读取分片数据 | ||||
|                     chunk_data = f.read(chunk_size) | ||||
|                     is_last = (chunk_id == total_chunks - 1) | ||||
|                          | ||||
|                     # 发送分片(Hex编码) | ||||
|                     chunk_msg = { | ||||
|                         "type": "chunk", | ||||
|                         "chunk_id": chunk_id, | ||||
|                         "is_last": is_last, | ||||
|                         "data": chunk_data.hex() | ||||
|                     } | ||||
|                     await self.websocket.send(json.dumps(chunk_msg)) | ||||
| 
 | ||||
|                     print(f"发送分片: {chunk_id}, 等待服务器接收") | ||||
|                          | ||||
|                     # 等待分片确认 | ||||
|                     chunk_ack = await self.websocket.recv() | ||||
|                     chunk_ack_data = json.loads(chunk_ack) | ||||
|                          | ||||
|                     if chunk_ack_data["type"] != "ack": | ||||
|                         raise Exception(f"分片 {chunk_id} 发送失败") | ||||
| 
 | ||||
|                     print(f"分片: {chunk_id} 接收成功") | ||||
|                          | ||||
|                     # 显示进度 | ||||
|                     progress = (chunk_id + 1) / total_chunks * 100 | ||||
|                     print(f"发送进度: {progress:.1f}% ({chunk_id + 1}/{total_chunks})", end="\r") | ||||
| 
 | ||||
|             print("\n所有分片发送完成,等待识别结果...") | ||||
| 
 | ||||
|             # 循环监听,直到得到最终消息 | ||||
|             is_listen = True | ||||
|             while is_listen: | ||||
|                 # 3. 获取识别结果 | ||||
|                 result = await self.websocket.recv() | ||||
|                 result_data = json.loads(result) | ||||
|                 print(f"识别结果: {result_data}") | ||||
| 
 | ||||
|                 # 获取最终消息 | ||||
|                 if result_data["type"] == "final_result": | ||||
|                     is_listen = False | ||||
| 
 | ||||
|                 # 服务端返回异常    | ||||
|                 if result_data["type"] == "error": | ||||
|                     raise Exception(f"识别失败: {result_data['message']}") | ||||
| 
 | ||||
|             total_time = time.time() - self.start_time | ||||
|             print(f"\n识别完成 (总耗时: {total_time:.2f}秒)") | ||||
|          | ||||
|         except Exception as e: | ||||
|             print(f"发送过程出错: {str(e)}") | ||||
|             self.is_connected = False | ||||
|             return False | ||||
| 
 | ||||
|     async def receive_messages(self): | ||||
|         """持续接收服务端的响应消息""" | ||||
|         if not self.websocket or self.is_connected: | ||||
|             return | ||||
|          | ||||
|         try: | ||||
|             while self.is_connected: | ||||
|                 response = await self.websocket.recv() | ||||
|                 message = json.loads(response) | ||||
|                 print(f"服务端消息: {message}") | ||||
|                  | ||||
|         except websockets.exceptions.ConnectionClosed: | ||||
|             print("\n与服务端的连接已关闭") | ||||
|             self.is_connected = False | ||||
|         except Exception as e: | ||||
|             print(f"\n接收消息出错: {str(e)}") | ||||
|             self.is_connected = False | ||||
| 
 | ||||
|     async def close(self): | ||||
|         """关闭WebSocket连接""" | ||||
|         if self.websocket and self.is_connected: | ||||
|             self.is_connected = False | ||||
|             await self.websocket.close() | ||||
|             print("已关闭连接") | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| async def main(args): | ||||
|     encoded_model = quote(args.model) | ||||
|     gateway_http_url = f"http://{args.gateway_host}:{args.gateway_port}"  # HTTP基础地址(用于健康检查) | ||||
|     gateway_ws_url = f"ws://{args.gateway_host}:{args.gateway_port}/learnware/models/openai/4pd/api/v1/voice/recognition/ws?model={encoded_model}"  # WebSocket地址(网关路由)  | ||||
|     print(f"gateway_ws_url: {gateway_ws_url}") | ||||
| 
 | ||||
|     # #http请求: 服务健康检查 | ||||
|     # health_check_url = f"{gateway_http_url}/health" | ||||
|     # try: | ||||
|     #     response = requests.get(health_check_url, timeout=10) | ||||
|     #     print("健康检查状态码:", response.status_code)   | ||||
|     #     print("健康检查响应内容:", response.json()) | ||||
|     # except Exception as e: | ||||
|     #     print(f"健康检查失败: {str(e)}") | ||||
|     #     return  # 健康检查失败则退出 | ||||
| 
 | ||||
|     # #http请求: 获取模型信息 | ||||
|     # model_url = f"{gateway_http_url}/model" | ||||
|     # try: | ||||
|     #     response = requests.get(model_url, timeout=10) | ||||
|     #     print("模型信息状态码:", response.status_code)   | ||||
|     #     print("模型信息响应内容:", response.json()) | ||||
|     # except Exception as e: | ||||
|     #     print(f"获取模型信息失败: {str(e)}") | ||||
|     #     return  # 获取模型信息失败则退出 | ||||
|      | ||||
|     # 构造请求头,携带 Bearer 令牌 | ||||
|     header = { | ||||
|         "Authorization": f"Bearer {args.token}"  # 核心:添加认证头 | ||||
|     } | ||||
|     # 初始化客户端并连接 | ||||
|     client = AudioClient(server_uri=gateway_ws_url, headers=header) | ||||
|     if not await client.connect(): | ||||
|         return  # 连接失败则退出 | ||||
|      | ||||
|     try: | ||||
|         # 启动一个任务持续接收消息(非阻塞) | ||||
|         receive_task = asyncio.create_task(client.receive_messages()) | ||||
| 
 | ||||
|         #发送音频文件 | ||||
|         file_path=args.audio_file | ||||
|         # 检查文件是否存在 | ||||
|         if not os.path.exists(file_path): | ||||
|             print(f"错误: 音频文件不存在 - {file_path}") | ||||
|             return | ||||
|         await client.send_audio(audio_path=file_path) | ||||
|          | ||||
|         # 等待所有任务完成 | ||||
|         await receive_task | ||||
|      | ||||
|     except Exception as e: | ||||
|         print(f"发生错误: {str(e)}") | ||||
|          | ||||
|     finally: | ||||
|         # 关闭连接 | ||||
|         await client.close() | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     # 配置命令行参数解析 | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="网关语音识别WebSocket测试脚本", | ||||
|         formatter_class=argparse.RawTextHelpFormatter | ||||
|     ) | ||||
| 
 | ||||
|     # 必传参数(网关连接+模型+音频文件) | ||||
|     required_args = parser.add_argument_group("必传参数") | ||||
|     required_args.add_argument( | ||||
|         "--gateway-host", | ||||
|         required=True, | ||||
|         help="网关IP或域名" | ||||
|     ) | ||||
|     required_args.add_argument( | ||||
|         "--gateway-port", | ||||
|         required=True, | ||||
|         type=int, | ||||
|         help="网关端口" | ||||
|     ) | ||||
|     required_args.add_argument( | ||||
|         "--model", | ||||
|         required=True, | ||||
|         help="语音识别模型ID" | ||||
|     ) | ||||
|     required_args.add_argument( | ||||
|         "--audio-file", | ||||
|         required=True, | ||||
|         help="本地音频文件路径" | ||||
|     ) | ||||
|     required_args.add_argument( | ||||
|         "--token", | ||||
|         required=True, | ||||
|         help="token" | ||||
|     ) | ||||
| 
 | ||||
|     # 解析参数并启动异步任务 | ||||
|     args = parser.parse_args() | ||||
| 
 | ||||
|     # 运行客户端 | ||||
|     asyncio.run(main(args)) | ||||
|      | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user