95 lines
2.7 KiB
Python
95 lines
2.7 KiB
Python
|
|
import base64
|
|||
|
|
import io
|
|||
|
|
import logging
|
|||
|
|
import os
|
|||
|
|
import time
|
|||
|
|
|
|||
|
|
import numpy as np
|
|||
|
|
from PIL import Image
|
|||
|
|
from fastapi import FastAPI, Request, status
|
|||
|
|
from fastapi.responses import RedirectResponse
|
|||
|
|
|
|||
|
|
# 导入你的DetectionTool类
|
|||
|
|
from utils.detection_tool import DetectionTool # 请替换为实际模块路径
|
|||
|
|
|
|||
|
|
"""
|
|||
|
|
pip install ultralytics
|
|||
|
|
pip install fastapi
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
# 基础配置
|
|||
|
|
logging.basicConfig(
|
|||
|
|
level=logging.INFO,
|
|||
|
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
|||
|
|
)
|
|||
|
|
logger = logging.getLogger(__name__)
|
|||
|
|
save_dir = os.getenv("SAVE_DIR", "save_dir")
|
|||
|
|
|
|||
|
|
# 初始化检测工具
|
|||
|
|
detector = DetectionTool()
|
|||
|
|
|
|||
|
|
# 创建FastAPI实例
|
|||
|
|
app = FastAPI()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def image_to_numpy(image_source):
|
|||
|
|
"""将图片源转换为numpy数组"""
|
|||
|
|
if isinstance(image_source, str):
|
|||
|
|
# 处理文件路径或URL
|
|||
|
|
if image_source.startswith(("http://", "https://")):
|
|||
|
|
import requests
|
|||
|
|
response = requests.get(image_source)
|
|||
|
|
if response.status_code != 200:
|
|||
|
|
raise ValueError(f"无法获取图片,状态码: {response.status_code}")
|
|||
|
|
image = Image.open(io.BytesIO(response.content))
|
|||
|
|
else:
|
|||
|
|
# 本地文件路径
|
|||
|
|
image = Image.open(image_source)
|
|||
|
|
elif isinstance(image_source, bytes):
|
|||
|
|
# 二进制数据
|
|||
|
|
image = Image.open(io.BytesIO(image_source))
|
|||
|
|
else:
|
|||
|
|
raise TypeError("不支持的图片源类型")
|
|||
|
|
|
|||
|
|
# 转换为RGB并转为numpy数组
|
|||
|
|
return np.array(image.convert("RGB"))
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.post(f"/v1/predict")
|
|||
|
|
async def predict(request: Request):
|
|||
|
|
"""核心预测接口:接收base64编码图片,返回检测结果"""
|
|||
|
|
data = await request.json()
|
|||
|
|
try:
|
|||
|
|
# 解码base64图片为numpy数组
|
|||
|
|
image_bytes = base64.b64decode(data["image"])
|
|||
|
|
img_np = image_to_numpy(image_bytes)
|
|||
|
|
|
|||
|
|
# 使用你的检测工具获取结果
|
|||
|
|
image_detected, result = detector.detect_image(img_np, image_format="pillow")
|
|||
|
|
|
|||
|
|
# 处理保存逻辑(如果需要)
|
|||
|
|
if save_dir is not None:
|
|||
|
|
os.makedirs(save_dir, exist_ok=True)
|
|||
|
|
# 可根据需要保存输入图片和结果
|
|||
|
|
timestamp = int(time.time())
|
|||
|
|
img_save_path = os.path.join(save_dir, f"input_{timestamp}.jpg")
|
|||
|
|
image_detected.save(img_save_path)
|
|||
|
|
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
logger.error(f"预测出错: {str(e)}")
|
|||
|
|
return str(e)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@app.get("/")
|
|||
|
|
async def index():
|
|||
|
|
"""根路径重定向到API文档"""
|
|||
|
|
return RedirectResponse(url="/docs", status_code=status.HTTP_302_FOUND)
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
import uvicorn
|
|||
|
|
|
|||
|
|
uvicorn.run(app, host="0.0.0.0", port=8080)
|