add file
This commit is contained in:
parent
9f1d98f159
commit
a3b9b6a102
94
main.py
Normal file
94
main.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from fastapi import FastAPI, Request, status
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
|
|
||||||
|
# 导入你的DetectionTool类
|
||||||
|
from utils.detection_tool import DetectionTool # 请替换为实际模块路径
|
||||||
|
|
||||||
|
"""
|
||||||
|
pip install ultralytics
|
||||||
|
pip install fastapi
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 基础配置
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
save_dir = os.getenv("SAVE_DIR", "save_dir")
|
||||||
|
|
||||||
|
# 初始化检测工具
|
||||||
|
detector = DetectionTool()
|
||||||
|
|
||||||
|
# 创建FastAPI实例
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
def image_to_numpy(image_source):
|
||||||
|
"""将图片源转换为numpy数组"""
|
||||||
|
if isinstance(image_source, str):
|
||||||
|
# 处理文件路径或URL
|
||||||
|
if image_source.startswith(("http://", "https://")):
|
||||||
|
import requests
|
||||||
|
response = requests.get(image_source)
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise ValueError(f"无法获取图片,状态码: {response.status_code}")
|
||||||
|
image = Image.open(io.BytesIO(response.content))
|
||||||
|
else:
|
||||||
|
# 本地文件路径
|
||||||
|
image = Image.open(image_source)
|
||||||
|
elif isinstance(image_source, bytes):
|
||||||
|
# 二进制数据
|
||||||
|
image = Image.open(io.BytesIO(image_source))
|
||||||
|
else:
|
||||||
|
raise TypeError("不支持的图片源类型")
|
||||||
|
|
||||||
|
# 转换为RGB并转为numpy数组
|
||||||
|
return np.array(image.convert("RGB"))
|
||||||
|
|
||||||
|
|
||||||
|
@app.post(f"/v1/predict")
|
||||||
|
async def predict(request: Request):
|
||||||
|
"""核心预测接口:接收base64编码图片,返回检测结果"""
|
||||||
|
data = await request.json()
|
||||||
|
try:
|
||||||
|
# 解码base64图片为numpy数组
|
||||||
|
image_bytes = base64.b64decode(data["image"])
|
||||||
|
img_np = image_to_numpy(image_bytes)
|
||||||
|
|
||||||
|
# 使用你的检测工具获取结果
|
||||||
|
image_detected, result = detector.detect_image(img_np, image_format="pillow")
|
||||||
|
|
||||||
|
# 处理保存逻辑(如果需要)
|
||||||
|
if save_dir is not None:
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
# 可根据需要保存输入图片和结果
|
||||||
|
timestamp = int(time.time())
|
||||||
|
img_save_path = os.path.join(save_dir, f"input_{timestamp}.jpg")
|
||||||
|
image_detected.save(img_save_path)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"预测出错: {str(e)}")
|
||||||
|
return str(e)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def index():
|
||||||
|
"""根路径重定向到API文档"""
|
||||||
|
return RedirectResponse(url="/docs", status_code=status.HTTP_302_FOUND)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=8080)
|
||||||
Loading…
Reference in New Issue
Block a user