add file
This commit is contained in:
parent
9f1d98f159
commit
a3b9b6a102
94
main.py
Normal file
94
main.py
Normal file
@ -0,0 +1,94 @@
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from fastapi import FastAPI, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
# 导入你的DetectionTool类
|
||||
from utils.detection_tool import DetectionTool # 请替换为实际模块路径
|
||||
|
||||
"""
|
||||
pip install ultralytics
|
||||
pip install fastapi
|
||||
"""
|
||||
|
||||
# 基础配置
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
save_dir = os.getenv("SAVE_DIR", "save_dir")
|
||||
|
||||
# 初始化检测工具
|
||||
detector = DetectionTool()
|
||||
|
||||
# 创建FastAPI实例
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
def image_to_numpy(image_source):
|
||||
"""将图片源转换为numpy数组"""
|
||||
if isinstance(image_source, str):
|
||||
# 处理文件路径或URL
|
||||
if image_source.startswith(("http://", "https://")):
|
||||
import requests
|
||||
response = requests.get(image_source)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"无法获取图片,状态码: {response.status_code}")
|
||||
image = Image.open(io.BytesIO(response.content))
|
||||
else:
|
||||
# 本地文件路径
|
||||
image = Image.open(image_source)
|
||||
elif isinstance(image_source, bytes):
|
||||
# 二进制数据
|
||||
image = Image.open(io.BytesIO(image_source))
|
||||
else:
|
||||
raise TypeError("不支持的图片源类型")
|
||||
|
||||
# 转换为RGB并转为numpy数组
|
||||
return np.array(image.convert("RGB"))
|
||||
|
||||
|
||||
@app.post(f"/v1/predict")
|
||||
async def predict(request: Request):
|
||||
"""核心预测接口:接收base64编码图片,返回检测结果"""
|
||||
data = await request.json()
|
||||
try:
|
||||
# 解码base64图片为numpy数组
|
||||
image_bytes = base64.b64decode(data["image"])
|
||||
img_np = image_to_numpy(image_bytes)
|
||||
|
||||
# 使用你的检测工具获取结果
|
||||
image_detected, result = detector.detect_image(img_np, image_format="pillow")
|
||||
|
||||
# 处理保存逻辑(如果需要)
|
||||
if save_dir is not None:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
# 可根据需要保存输入图片和结果
|
||||
timestamp = int(time.time())
|
||||
img_save_path = os.path.join(save_dir, f"input_{timestamp}.jpg")
|
||||
image_detected.save(img_save_path)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"预测出错: {str(e)}")
|
||||
return str(e)
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def index():
|
||||
"""根路径重定向到API文档"""
|
||||
return RedirectResponse(url="/docs", status_code=status.HTTP_302_FOUND)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8080)
|
||||
Loading…
Reference in New Issue
Block a user