add file
This commit is contained in:
parent
67a07b1675
commit
f747aaacd4
251
client.py
Normal file
251
client.py
Normal file
@ -0,0 +1,251 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import websockets
|
||||
import requests
|
||||
from urllib.parse import quote
|
||||
|
||||
class AudioClient:
|
||||
def __init__(self, server_uri: str, headers):
|
||||
self.server_uri = server_uri # 服务端WebSocket地址
|
||||
self.headers = headers
|
||||
self.websocket: Optional[websockets.WebSocketClientProtocol] = None
|
||||
self.start_time: float = 0 # 记录开始时间
|
||||
self.is_connected = False # 连接状态标记
|
||||
|
||||
async def connect(self):
|
||||
"""连接到WebSocket服务端"""
|
||||
try:
|
||||
self.websocket = await websockets.connect(self.server_uri, extra_headers=self.headers)
|
||||
self.is_connected = True
|
||||
print(f"已连接到服务端: {self.server_uri}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"连接失败: {str(e)}")
|
||||
self.is_connected = False
|
||||
return False
|
||||
|
||||
async def send_audio(self, audio_path: str, chunk_size: int = 512 * 1024):
|
||||
"""发送音频文件元数据(名称、大小、分片数等)"""
|
||||
if not self.websocket or not self.is_connected:
|
||||
print("连接未建立或已关闭,尝试重新连接...")
|
||||
if not await self.connect():
|
||||
return False
|
||||
|
||||
print("开始发送请求..........")
|
||||
|
||||
# 记录开始时间
|
||||
self.start_time = time.time()
|
||||
|
||||
# 音频文件信息
|
||||
filename = os.path.basename(audio_path)
|
||||
file_size = os.path.getsize(audio_path)
|
||||
total_chunks = (file_size + chunk_size - 1) // chunk_size # 计算总分片数
|
||||
|
||||
meta = {
|
||||
"type": "metadata",
|
||||
"filename": filename,
|
||||
"total_size": file_size,
|
||||
"total_chunks": total_chunks,
|
||||
"format": audio_path.split(".")[-1],
|
||||
}
|
||||
|
||||
print(f"发送文件源信息: {meta}")
|
||||
|
||||
try:
|
||||
await self.websocket.send(json.dumps(meta))
|
||||
print("已发送元数据,等待确认...")
|
||||
|
||||
# 等待元数据确认
|
||||
ack = await self.websocket.recv()
|
||||
ack_data = json.loads(ack)
|
||||
if ack_data["type"] != "ack":
|
||||
raise Exception(f"元数据发送失败: {ack_data}")
|
||||
|
||||
print(f"开始发送文件分片: {filename}")
|
||||
|
||||
# 分片发送音频数据
|
||||
with open(audio_path, "rb") as f:
|
||||
for chunk_id in range(total_chunks):
|
||||
# 读取分片数据
|
||||
chunk_data = f.read(chunk_size)
|
||||
is_last = (chunk_id == total_chunks - 1)
|
||||
|
||||
# 发送分片(Hex编码)
|
||||
chunk_msg = {
|
||||
"type": "chunk",
|
||||
"chunk_id": chunk_id,
|
||||
"is_last": is_last,
|
||||
"data": chunk_data.hex()
|
||||
}
|
||||
await self.websocket.send(json.dumps(chunk_msg))
|
||||
|
||||
print(f"发送分片: {chunk_id}, 等待服务器接收")
|
||||
|
||||
# 等待分片确认
|
||||
chunk_ack = await self.websocket.recv()
|
||||
chunk_ack_data = json.loads(chunk_ack)
|
||||
|
||||
if chunk_ack_data["type"] != "ack":
|
||||
raise Exception(f"分片 {chunk_id} 发送失败")
|
||||
|
||||
print(f"分片: {chunk_id} 接收成功")
|
||||
|
||||
# 显示进度
|
||||
progress = (chunk_id + 1) / total_chunks * 100
|
||||
print(f"发送进度: {progress:.1f}% ({chunk_id + 1}/{total_chunks})", end="\r")
|
||||
|
||||
print("\n所有分片发送完成,等待识别结果...")
|
||||
|
||||
# 循环监听,直到得到最终消息
|
||||
is_listen = True
|
||||
while is_listen:
|
||||
# 3. 获取识别结果
|
||||
result = await self.websocket.recv()
|
||||
result_data = json.loads(result)
|
||||
print(f"识别结果: {result_data}")
|
||||
|
||||
# 获取最终消息
|
||||
if result_data["type"] == "final_result":
|
||||
is_listen = False
|
||||
|
||||
# 服务端返回异常
|
||||
if result_data["type"] == "error":
|
||||
raise Exception(f"识别失败: {result_data['message']}")
|
||||
|
||||
total_time = time.time() - self.start_time
|
||||
print(f"\n识别完成 (总耗时: {total_time:.2f}秒)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"发送过程出错: {str(e)}")
|
||||
self.is_connected = False
|
||||
return False
|
||||
|
||||
async def receive_messages(self):
|
||||
"""持续接收服务端的响应消息"""
|
||||
if not self.websocket or self.is_connected:
|
||||
return
|
||||
|
||||
try:
|
||||
while self.is_connected:
|
||||
response = await self.websocket.recv()
|
||||
message = json.loads(response)
|
||||
print(f"服务端消息: {message}")
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
print("\n与服务端的连接已关闭")
|
||||
self.is_connected = False
|
||||
except Exception as e:
|
||||
print(f"\n接收消息出错: {str(e)}")
|
||||
self.is_connected = False
|
||||
|
||||
async def close(self):
|
||||
"""关闭WebSocket连接"""
|
||||
if self.websocket and self.is_connected:
|
||||
self.is_connected = False
|
||||
await self.websocket.close()
|
||||
print("已关闭连接")
|
||||
|
||||
|
||||
|
||||
async def main(args):
|
||||
encoded_model = quote(args.model)
|
||||
gateway_http_url = f"http://{args.gateway_host}:{args.gateway_port}" # HTTP基础地址(用于健康检查)
|
||||
gateway_ws_url = f"ws://{args.gateway_host}:{args.gateway_port}/learnware/models/openai/4pd/api/v1/voice/recognition/ws?model={encoded_model}" # WebSocket地址(网关路由)
|
||||
print(f"gateway_ws_url: {gateway_ws_url}")
|
||||
|
||||
# #http请求: 服务健康检查
|
||||
# health_check_url = f"{gateway_http_url}/health"
|
||||
# try:
|
||||
# response = requests.get(health_check_url, timeout=10)
|
||||
# print("健康检查状态码:", response.status_code)
|
||||
# print("健康检查响应内容:", response.json())
|
||||
# except Exception as e:
|
||||
# print(f"健康检查失败: {str(e)}")
|
||||
# return # 健康检查失败则退出
|
||||
|
||||
# #http请求: 获取模型信息
|
||||
# model_url = f"{gateway_http_url}/model"
|
||||
# try:
|
||||
# response = requests.get(model_url, timeout=10)
|
||||
# print("模型信息状态码:", response.status_code)
|
||||
# print("模型信息响应内容:", response.json())
|
||||
# except Exception as e:
|
||||
# print(f"获取模型信息失败: {str(e)}")
|
||||
# return # 获取模型信息失败则退出
|
||||
|
||||
# 构造请求头,携带 Bearer 令牌
|
||||
header = {
|
||||
"Authorization": f"Bearer {args.token}" # 核心:添加认证头
|
||||
}
|
||||
# 初始化客户端并连接
|
||||
client = AudioClient(server_uri=gateway_ws_url, headers=header)
|
||||
if not await client.connect():
|
||||
return # 连接失败则退出
|
||||
|
||||
try:
|
||||
# 启动一个任务持续接收消息(非阻塞)
|
||||
receive_task = asyncio.create_task(client.receive_messages())
|
||||
|
||||
#发送音频文件
|
||||
file_path=args.audio_file
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(file_path):
|
||||
print(f"错误: 音频文件不存在 - {file_path}")
|
||||
return
|
||||
await client.send_audio(audio_path=file_path)
|
||||
|
||||
# 等待所有任务完成
|
||||
await receive_task
|
||||
|
||||
except Exception as e:
|
||||
print(f"发生错误: {str(e)}")
|
||||
|
||||
finally:
|
||||
# 关闭连接
|
||||
await client.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 配置命令行参数解析
|
||||
parser = argparse.ArgumentParser(
|
||||
description="网关语音识别WebSocket测试脚本",
|
||||
formatter_class=argparse.RawTextHelpFormatter
|
||||
)
|
||||
|
||||
# 必传参数(网关连接+模型+音频文件)
|
||||
required_args = parser.add_argument_group("必传参数")
|
||||
required_args.add_argument(
|
||||
"--gateway-host",
|
||||
required=True,
|
||||
help="网关IP或域名"
|
||||
)
|
||||
required_args.add_argument(
|
||||
"--gateway-port",
|
||||
required=True,
|
||||
type=int,
|
||||
help="网关端口"
|
||||
)
|
||||
required_args.add_argument(
|
||||
"--model",
|
||||
required=True,
|
||||
help="语音识别模型ID"
|
||||
)
|
||||
required_args.add_argument(
|
||||
"--audio-file",
|
||||
required=True,
|
||||
help="本地音频文件路径"
|
||||
)
|
||||
required_args.add_argument(
|
||||
"--token",
|
||||
required=True,
|
||||
help="token"
|
||||
)
|
||||
|
||||
# 解析参数并启动异步任务
|
||||
args = parser.parse_args()
|
||||
|
||||
# 运行客户端
|
||||
asyncio.run(main(args))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user