import base64 import io import logging import os import time import numpy as np from PIL import Image from fastapi import FastAPI, Request, status from fastapi.responses import RedirectResponse # 导入你的DetectionTool类 from utils.detection_tool import DetectionTool # 请替换为实际模块路径 """ pip install ultralytics pip install fastapi """ # 基础配置 logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) save_dir = os.getenv("SAVE_DIR", "save_dir") # 初始化检测工具 detector = DetectionTool() # 创建FastAPI实例 app = FastAPI() def image_to_numpy(image_source): """将图片源转换为numpy数组""" if isinstance(image_source, str): # 处理文件路径或URL if image_source.startswith(("http://", "https://")): import requests response = requests.get(image_source) if response.status_code != 200: raise ValueError(f"无法获取图片,状态码: {response.status_code}") image = Image.open(io.BytesIO(response.content)) else: # 本地文件路径 image = Image.open(image_source) elif isinstance(image_source, bytes): # 二进制数据 image = Image.open(io.BytesIO(image_source)) else: raise TypeError("不支持的图片源类型") # 转换为RGB并转为numpy数组 return np.array(image.convert("RGB")) @app.post(f"/v1/predict") async def predict(request: Request): """核心预测接口:接收base64编码图片,返回检测结果""" data = await request.json() try: # 解码base64图片为numpy数组 image_bytes = base64.b64decode(data["image"]) img_np = image_to_numpy(image_bytes) # 使用你的检测工具获取结果 image_detected, result = detector.detect_image(img_np, image_format="pillow") # 处理保存逻辑(如果需要) if save_dir is not None: os.makedirs(save_dir, exist_ok=True) # 可根据需要保存输入图片和结果 timestamp = int(time.time()) img_save_path = os.path.join(save_dir, f"input_{timestamp}.jpg") image_detected.save(img_save_path) return result except Exception as e: logger.error(f"预测出错: {str(e)}") return str(e) @app.get("/") async def index(): """根路径重定向到API文档""" return RedirectResponse(url="/docs", status_code=status.HTTP_302_FOUND) if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8080)