diff --git a/main.py b/main.py new file mode 100644 index 0000000..9474657 --- /dev/null +++ b/main.py @@ -0,0 +1,94 @@ +import base64 +import io +import logging +import os +import time + +import numpy as np +from PIL import Image +from fastapi import FastAPI, Request, status +from fastapi.responses import RedirectResponse + +# 导入你的DetectionTool类 +from utils.detection_tool import DetectionTool # 请替换为实际模块路径 + +""" +pip install ultralytics +pip install fastapi +""" + +# 基础配置 +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) +save_dir = os.getenv("SAVE_DIR", "save_dir") + +# 初始化检测工具 +detector = DetectionTool() + +# 创建FastAPI实例 +app = FastAPI() + + +def image_to_numpy(image_source): + """将图片源转换为numpy数组""" + if isinstance(image_source, str): + # 处理文件路径或URL + if image_source.startswith(("http://", "https://")): + import requests + response = requests.get(image_source) + if response.status_code != 200: + raise ValueError(f"无法获取图片,状态码: {response.status_code}") + image = Image.open(io.BytesIO(response.content)) + else: + # 本地文件路径 + image = Image.open(image_source) + elif isinstance(image_source, bytes): + # 二进制数据 + image = Image.open(io.BytesIO(image_source)) + else: + raise TypeError("不支持的图片源类型") + + # 转换为RGB并转为numpy数组 + return np.array(image.convert("RGB")) + + +@app.post(f"/v1/predict") +async def predict(request: Request): + """核心预测接口:接收base64编码图片,返回检测结果""" + data = await request.json() + try: + # 解码base64图片为numpy数组 + image_bytes = base64.b64decode(data["image"]) + img_np = image_to_numpy(image_bytes) + + # 使用你的检测工具获取结果 + image_detected, result = detector.detect_image(img_np, image_format="pillow") + + # 处理保存逻辑(如果需要) + if save_dir is not None: + os.makedirs(save_dir, exist_ok=True) + # 可根据需要保存输入图片和结果 + timestamp = int(time.time()) + img_save_path = os.path.join(save_dir, f"input_{timestamp}.jpg") + image_detected.save(img_save_path) + + return result + + except Exception as e: + logger.error(f"预测出错: {str(e)}") + return str(e) + + +@app.get("/") +async def index(): + """根路径重定向到API文档""" + return RedirectResponse(url="/docs", status_code=status.HTTP_302_FOUND) + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8080)