This commit is contained in:
weitong 2025-10-30 09:59:06 +08:00
parent 9f1d98f159
commit a3b9b6a102

94
main.py Normal file
View File

@ -0,0 +1,94 @@
import base64
import io
import logging
import os
import time
import numpy as np
from PIL import Image
from fastapi import FastAPI, Request, status
from fastapi.responses import RedirectResponse
# 导入你的DetectionTool类
from utils.detection_tool import DetectionTool # 请替换为实际模块路径
"""
pip install ultralytics
pip install fastapi
"""
# 基础配置
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
save_dir = os.getenv("SAVE_DIR", "save_dir")
# 初始化检测工具
detector = DetectionTool()
# 创建FastAPI实例
app = FastAPI()
def image_to_numpy(image_source):
"""将图片源转换为numpy数组"""
if isinstance(image_source, str):
# 处理文件路径或URL
if image_source.startswith(("http://", "https://")):
import requests
response = requests.get(image_source)
if response.status_code != 200:
raise ValueError(f"无法获取图片,状态码: {response.status_code}")
image = Image.open(io.BytesIO(response.content))
else:
# 本地文件路径
image = Image.open(image_source)
elif isinstance(image_source, bytes):
# 二进制数据
image = Image.open(io.BytesIO(image_source))
else:
raise TypeError("不支持的图片源类型")
# 转换为RGB并转为numpy数组
return np.array(image.convert("RGB"))
@app.post(f"/v1/predict")
async def predict(request: Request):
"""核心预测接口接收base64编码图片返回检测结果"""
data = await request.json()
try:
# 解码base64图片为numpy数组
image_bytes = base64.b64decode(data["image"])
img_np = image_to_numpy(image_bytes)
# 使用你的检测工具获取结果
image_detected, result = detector.detect_image(img_np, image_format="pillow")
# 处理保存逻辑(如果需要)
if save_dir is not None:
os.makedirs(save_dir, exist_ok=True)
# 可根据需要保存输入图片和结果
timestamp = int(time.time())
img_save_path = os.path.join(save_dir, f"input_{timestamp}.jpg")
image_detected.save(img_save_path)
return result
except Exception as e:
logger.error(f"预测出错: {str(e)}")
return str(e)
@app.get("/")
async def index():
"""根路径重定向到API文档"""
return RedirectResponse(url="/docs", status_code=status.HTTP_302_FOUND)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8080)