add file
This commit is contained in:
		
							parent
							
								
									9f1d98f159
								
							
						
					
					
						commit
						a3b9b6a102
					
				
							
								
								
									
										94
									
								
								main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										94
									
								
								main.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,94 @@ | ||||
| import base64 | ||||
| import io | ||||
| import logging | ||||
| import os | ||||
| import time | ||||
| 
 | ||||
| import numpy as np | ||||
| from PIL import Image | ||||
| from fastapi import FastAPI, Request, status | ||||
| from fastapi.responses import RedirectResponse | ||||
| 
 | ||||
| # 导入你的DetectionTool类 | ||||
| from utils.detection_tool import DetectionTool  # 请替换为实际模块路径 | ||||
| 
 | ||||
| """ | ||||
| pip install ultralytics | ||||
| pip install fastapi | ||||
| """ | ||||
| 
 | ||||
| # 基础配置 | ||||
| logging.basicConfig( | ||||
|     level=logging.INFO, | ||||
|     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", | ||||
| ) | ||||
| logger = logging.getLogger(__name__) | ||||
| save_dir = os.getenv("SAVE_DIR", "save_dir") | ||||
| 
 | ||||
| # 初始化检测工具 | ||||
| detector = DetectionTool() | ||||
| 
 | ||||
| # 创建FastAPI实例 | ||||
| app = FastAPI() | ||||
| 
 | ||||
| 
 | ||||
| def image_to_numpy(image_source): | ||||
|     """将图片源转换为numpy数组""" | ||||
|     if isinstance(image_source, str): | ||||
|         # 处理文件路径或URL | ||||
|         if image_source.startswith(("http://", "https://")): | ||||
|             import requests | ||||
|             response = requests.get(image_source) | ||||
|             if response.status_code != 200: | ||||
|                 raise ValueError(f"无法获取图片,状态码: {response.status_code}") | ||||
|             image = Image.open(io.BytesIO(response.content)) | ||||
|         else: | ||||
|             # 本地文件路径 | ||||
|             image = Image.open(image_source) | ||||
|     elif isinstance(image_source, bytes): | ||||
|         # 二进制数据 | ||||
|         image = Image.open(io.BytesIO(image_source)) | ||||
|     else: | ||||
|         raise TypeError("不支持的图片源类型") | ||||
| 
 | ||||
|     # 转换为RGB并转为numpy数组 | ||||
|     return np.array(image.convert("RGB")) | ||||
| 
 | ||||
| 
 | ||||
| @app.post(f"/v1/predict") | ||||
| async def predict(request: Request): | ||||
|     """核心预测接口:接收base64编码图片,返回检测结果""" | ||||
|     data = await request.json() | ||||
|     try: | ||||
|         # 解码base64图片为numpy数组 | ||||
|         image_bytes = base64.b64decode(data["image"]) | ||||
|         img_np = image_to_numpy(image_bytes) | ||||
| 
 | ||||
|         # 使用你的检测工具获取结果 | ||||
|         image_detected, result = detector.detect_image(img_np, image_format="pillow") | ||||
| 
 | ||||
|         # 处理保存逻辑(如果需要) | ||||
|         if save_dir is not None: | ||||
|             os.makedirs(save_dir, exist_ok=True) | ||||
|             # 可根据需要保存输入图片和结果 | ||||
|             timestamp = int(time.time()) | ||||
|             img_save_path = os.path.join(save_dir, f"input_{timestamp}.jpg") | ||||
|             image_detected.save(img_save_path) | ||||
| 
 | ||||
|         return result | ||||
| 
 | ||||
|     except Exception as e: | ||||
|         logger.error(f"预测出错: {str(e)}") | ||||
|         return str(e) | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/") | ||||
| async def index(): | ||||
|     """根路径重定向到API文档""" | ||||
|     return RedirectResponse(url="/docs", status_code=status.HTTP_302_FOUND) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     import uvicorn | ||||
| 
 | ||||
|     uvicorn.run(app, host="0.0.0.0", port=8080) | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user