95 lines
2.7 KiB
Python
95 lines
2.7 KiB
Python
import base64
|
||
import io
|
||
import logging
|
||
import os
|
||
import time
|
||
|
||
import numpy as np
|
||
from PIL import Image
|
||
from fastapi import FastAPI, Request, status
|
||
from fastapi.responses import RedirectResponse
|
||
|
||
# 导入你的DetectionTool类
|
||
from utils.detection_tool import DetectionTool # 请替换为实际模块路径
|
||
|
||
"""
|
||
pip install ultralytics
|
||
pip install fastapi
|
||
"""
|
||
|
||
# 基础配置
|
||
logging.basicConfig(
|
||
level=logging.INFO,
|
||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||
)
|
||
logger = logging.getLogger(__name__)
|
||
save_dir = os.getenv("SAVE_DIR", "save_dir")
|
||
|
||
# 初始化检测工具
|
||
detector = DetectionTool()
|
||
|
||
# 创建FastAPI实例
|
||
app = FastAPI()
|
||
|
||
|
||
def image_to_numpy(image_source):
|
||
"""将图片源转换为numpy数组"""
|
||
if isinstance(image_source, str):
|
||
# 处理文件路径或URL
|
||
if image_source.startswith(("http://", "https://")):
|
||
import requests
|
||
response = requests.get(image_source)
|
||
if response.status_code != 200:
|
||
raise ValueError(f"无法获取图片,状态码: {response.status_code}")
|
||
image = Image.open(io.BytesIO(response.content))
|
||
else:
|
||
# 本地文件路径
|
||
image = Image.open(image_source)
|
||
elif isinstance(image_source, bytes):
|
||
# 二进制数据
|
||
image = Image.open(io.BytesIO(image_source))
|
||
else:
|
||
raise TypeError("不支持的图片源类型")
|
||
|
||
# 转换为RGB并转为numpy数组
|
||
return np.array(image.convert("RGB"))
|
||
|
||
|
||
@app.post(f"/v1/predict")
|
||
async def predict(request: Request):
|
||
"""核心预测接口:接收base64编码图片,返回检测结果"""
|
||
data = await request.json()
|
||
try:
|
||
# 解码base64图片为numpy数组
|
||
image_bytes = base64.b64decode(data["image"])
|
||
img_np = image_to_numpy(image_bytes)
|
||
|
||
# 使用你的检测工具获取结果
|
||
image_detected, result = detector.detect_image(img_np, image_format="pillow")
|
||
|
||
# 处理保存逻辑(如果需要)
|
||
if save_dir is not None:
|
||
os.makedirs(save_dir, exist_ok=True)
|
||
# 可根据需要保存输入图片和结果
|
||
timestamp = int(time.time())
|
||
img_save_path = os.path.join(save_dir, f"input_{timestamp}.jpg")
|
||
image_detected.save(img_save_path)
|
||
|
||
return result
|
||
|
||
except Exception as e:
|
||
logger.error(f"预测出错: {str(e)}")
|
||
return str(e)
|
||
|
||
|
||
@app.get("/")
|
||
async def index():
|
||
"""根路径重定向到API文档"""
|
||
return RedirectResponse(url="/docs", status_code=status.HTTP_302_FOUND)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
import uvicorn
|
||
|
||
uvicorn.run(app, host="0.0.0.0", port=8080)
|