95 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			95 lines
		
	
	
		
			2.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import base64
 | ||
| import io
 | ||
| import logging
 | ||
| import os
 | ||
| import time
 | ||
| 
 | ||
| import numpy as np
 | ||
| from PIL import Image
 | ||
| from fastapi import FastAPI, Request, status
 | ||
| from fastapi.responses import RedirectResponse
 | ||
| 
 | ||
| # 导入你的DetectionTool类
 | ||
| from utils.detection_tool import DetectionTool  # 请替换为实际模块路径
 | ||
| 
 | ||
| """
 | ||
| pip install ultralytics
 | ||
| pip install fastapi
 | ||
| """
 | ||
| 
 | ||
| # 基础配置
 | ||
| logging.basicConfig(
 | ||
|     level=logging.INFO,
 | ||
|     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
 | ||
| )
 | ||
| logger = logging.getLogger(__name__)
 | ||
| save_dir = os.getenv("SAVE_DIR", "save_dir")
 | ||
| 
 | ||
| # 初始化检测工具
 | ||
| detector = DetectionTool()
 | ||
| 
 | ||
| # 创建FastAPI实例
 | ||
| app = FastAPI()
 | ||
| 
 | ||
| 
 | ||
| def image_to_numpy(image_source):
 | ||
|     """将图片源转换为numpy数组"""
 | ||
|     if isinstance(image_source, str):
 | ||
|         # 处理文件路径或URL
 | ||
|         if image_source.startswith(("http://", "https://")):
 | ||
|             import requests
 | ||
|             response = requests.get(image_source)
 | ||
|             if response.status_code != 200:
 | ||
|                 raise ValueError(f"无法获取图片,状态码: {response.status_code}")
 | ||
|             image = Image.open(io.BytesIO(response.content))
 | ||
|         else:
 | ||
|             # 本地文件路径
 | ||
|             image = Image.open(image_source)
 | ||
|     elif isinstance(image_source, bytes):
 | ||
|         # 二进制数据
 | ||
|         image = Image.open(io.BytesIO(image_source))
 | ||
|     else:
 | ||
|         raise TypeError("不支持的图片源类型")
 | ||
| 
 | ||
|     # 转换为RGB并转为numpy数组
 | ||
|     return np.array(image.convert("RGB"))
 | ||
| 
 | ||
| 
 | ||
| @app.post(f"/v1/predict")
 | ||
| async def predict(request: Request):
 | ||
|     """核心预测接口:接收base64编码图片,返回检测结果"""
 | ||
|     data = await request.json()
 | ||
|     try:
 | ||
|         # 解码base64图片为numpy数组
 | ||
|         image_bytes = base64.b64decode(data["image"])
 | ||
|         img_np = image_to_numpy(image_bytes)
 | ||
| 
 | ||
|         # 使用你的检测工具获取结果
 | ||
|         image_detected, result = detector.detect_image(img_np, image_format="pillow")
 | ||
| 
 | ||
|         # 处理保存逻辑(如果需要)
 | ||
|         if save_dir is not None:
 | ||
|             os.makedirs(save_dir, exist_ok=True)
 | ||
|             # 可根据需要保存输入图片和结果
 | ||
|             timestamp = int(time.time())
 | ||
|             img_save_path = os.path.join(save_dir, f"input_{timestamp}.jpg")
 | ||
|             image_detected.save(img_save_path)
 | ||
| 
 | ||
|         return result
 | ||
| 
 | ||
|     except Exception as e:
 | ||
|         logger.error(f"预测出错: {str(e)}")
 | ||
|         return str(e)
 | ||
| 
 | ||
| 
 | ||
| @app.get("/")
 | ||
| async def index():
 | ||
|     """根路径重定向到API文档"""
 | ||
|     return RedirectResponse(url="/docs", status_code=status.HTTP_302_FOUND)
 | ||
| 
 | ||
| 
 | ||
| if __name__ == "__main__":
 | ||
|     import uvicorn
 | ||
| 
 | ||
|     uvicorn.run(app, host="0.0.0.0", port=8080)
 |