add file
This commit is contained in:
		
							parent
							
								
									a3b9b6a102
								
							
						
					
					
						commit
						eeaf19c4e5
					
				
							
								
								
									
										234
									
								
								utils/detection_tool.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										234
									
								
								utils/detection_tool.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,234 @@ | ||||
| import cv2 | ||||
| import numpy as np | ||||
| 
 | ||||
| from pathlib import Path | ||||
| from datetime import datetime | ||||
| from PIL import Image, ImageDraw, ImageFont | ||||
| from config import Config | ||||
| from ultralytics import YOLO | ||||
| 
 | ||||
| 
 | ||||
| # 检测工具类 | ||||
| class DetectionTool(object): | ||||
| 
 | ||||
|     def __init__(self, config: Config = None): | ||||
|         # 设置基本信息 | ||||
|         if config is None: | ||||
|             self.config = Config() | ||||
|         else: | ||||
|             self.config = config | ||||
| 
 | ||||
|         # 创建预测模型 | ||||
|         self.model = YOLO(self.config.weights_path) | ||||
| 
 | ||||
|     # 预测图片的主方法 | ||||
|     def detect_image(self, image_data, image_format="cv", only_return_result=False): | ||||
|         """ | ||||
|         :param image_data: 图片数据(需要为numpy数组格式) | ||||
|         :param image_format: 需要导出的图片格式(支持pillow/cv) | ||||
|         :param only_return_result: 只返回预测结果 | ||||
|         :return: 标注好的图片 | ||||
|         """ | ||||
|         # 获取预测结果 | ||||
|         pred_result = self.model(image_data, verbose=False, iou=0.35)[0].boxes | ||||
| 
 | ||||
|         # 分解预测结果 | ||||
|         pred_boxes, pred_classes, pred_scores = pred_result.xyxy, pred_result.cls, pred_result.conf | ||||
|         result = { | ||||
|             "boxes": pred_boxes.cpu().numpy(), | ||||
|             "classes": pred_classes.cpu().numpy(), | ||||
|             "scores": pred_scores.cpu().numpy() | ||||
|         } | ||||
| 
 | ||||
|         # 返回预测结果 | ||||
|         if only_return_result: | ||||
|             return result | ||||
| 
 | ||||
|         # 如果目标框数量大于0 | ||||
|         if len(pred_boxes) > 0: | ||||
|             # 将矩阵图片转换为Pillow对象 | ||||
|             image_data = Image.fromarray(image_data) | ||||
| 
 | ||||
|             # 设置字体 | ||||
|             font = ImageFont.truetype( | ||||
|                 font=self.config.font_path, | ||||
|                 size=np.floor(3e-2 * image_data.size[1] + 0.5).astype('int32') | ||||
|             ) | ||||
|             thickness = max((image_data.size[0] + image_data.size[1]) // 300, 1) | ||||
| 
 | ||||
|             # 创建画板 | ||||
|             image_draw = ImageDraw.Draw(image_data) | ||||
| 
 | ||||
|             # 遍历所预测的标注框信息 | ||||
|             for i, bndbox in enumerate(pred_boxes): | ||||
|                 # 获取类别和得分 | ||||
|                 class_name = self.config.classes_name[int(pred_classes[i])] | ||||
| 
 | ||||
|                 # 获取切割坐标 | ||||
|                 x_min, y_min, x_max, y_max = bndbox.cpu().numpy().astype("int32") | ||||
| 
 | ||||
|                 # 设置文本内容 | ||||
|                 text_content = f"{class_name}||{format(float(pred_scores[i]), '.2f')}" | ||||
| 
 | ||||
|                 # 画框框 | ||||
|                 for line_num in range(thickness): | ||||
|                     image_draw.rectangle( | ||||
|                         [x_min + line_num, y_min + line_num, x_max - line_num, y_max - line_num], | ||||
|                         outline=(255, 0, 0) | ||||
|                     ) | ||||
|                 text_size = image_draw.textsize(text_content, font) | ||||
|                 text_content = text_content.encode('utf-8') | ||||
| 
 | ||||
|                 # 计算文本的坐标 | ||||
|                 text_y_min = y_min - text_size[1] | ||||
|                 text_x_min = x_min | ||||
|                 text_y_max = y_min | ||||
|                 text_x_max = x_min + text_size[0] | ||||
| 
 | ||||
|                 # 添加文本内容 | ||||
|                 image_draw.rectangle( | ||||
|                     [(text_x_min, text_y_min), (text_x_max, text_y_max)], | ||||
|                     fill=(255, 0, 0) | ||||
|                 ) | ||||
|                 image_draw.text((text_x_min, text_y_min), str(text_content, "utf-8"), fill=(0, 0, 0), font=font) | ||||
| 
 | ||||
|             # 如果选择格式为cv, 则进行转换格式 | ||||
|             if image_format == "cv": | ||||
|                 image_data = np.asarray(image_data, dtype="uint8") | ||||
| 
 | ||||
|         else: | ||||
|             # 如果选择格式为pillow, 则进行转换格式 | ||||
|             if image_format == "pillow": | ||||
|                 image_data = Image.fromarray(image_data) | ||||
| 
 | ||||
|         return image_data, result | ||||
| 
 | ||||
|     # 测试对图片进行标注 | ||||
|     def test_image(self): | ||||
|         # 开启预测 | ||||
|         while True: | ||||
|             # 输入图片路径 | ||||
|             image_path = Path(input("请输入图片的路径:")) | ||||
| 
 | ||||
|             # 读取图片 | ||||
|             image_data = self.load_image(str(image_path)) | ||||
| 
 | ||||
|             # 预测 | ||||
|             image_data, _ = self.detect_image(image_data, "pillow") | ||||
| 
 | ||||
|             # 展示 | ||||
|             image_data.show("images_save-pic") | ||||
| 
 | ||||
|             # 是否继续 | ||||
|             is_next = input("是否继续?输入(y/n):") | ||||
|             if is_next == "n": | ||||
|                 break | ||||
| 
 | ||||
|     def test_video(self): | ||||
|         # 输入视频路径 | ||||
|         video_path = Path(str(input("请输入视频的路径:"))) | ||||
| 
 | ||||
|         # 输入保存的路径 | ||||
|         video_save = Path(str(input("请输入保存的文件夹路径:"))) | ||||
|         cap = cv2.VideoCapture(str(video_path)) | ||||
| 
 | ||||
|         # 获取视频参数 | ||||
|         frame_speed = int(cap.get(5)) | ||||
|         num_frame = int(cap.get(7)) | ||||
|         width = int(cap.get(3)) | ||||
|         height = int(cap.get(4)) | ||||
| 
 | ||||
|         # 设置导出视频参数 | ||||
|         fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v') | ||||
|         out = cv2.VideoWriter(str(video_save.joinpath(f"{video_path.name[:-4]}.mp4")), fourcc, frame_speed, | ||||
|                               (width, height), True) | ||||
| 
 | ||||
|         count = 0 | ||||
| 
 | ||||
|         while cap.isOpened(): | ||||
|             # 初始时间记录 | ||||
|             start_time = datetime.now() | ||||
| 
 | ||||
|             # 读取一帧数据 | ||||
|             ret, frame_image = cap.read() | ||||
|             if not ret: | ||||
|                 break | ||||
| 
 | ||||
|             # 进行标注 | ||||
|             frame_image = cv2.cvtColor(frame_image, cv2.COLOR_BGR2RGB) | ||||
|             frame_image, _ = self.detect_image(frame_image) | ||||
|             frame_image = cv2.cvtColor(frame_image, cv2.COLOR_RGB2BGR) | ||||
| 
 | ||||
|             # 写入视频 | ||||
|             out.write(frame_image) | ||||
|             count += 1 | ||||
| 
 | ||||
|             # 结束时间 | ||||
|             end_time = datetime.now() | ||||
| 
 | ||||
|             # 剩余时间 | ||||
|             need_time = (end_time - start_time).microseconds / 1000000 * (num_frame - count) | ||||
|             fps = 1000 / ((end_time - start_time).microseconds / 1000) | ||||
| 
 | ||||
|             print( | ||||
|                 f"\r正在渲染: {count}/{num_frame}|FPS:{fps}|大约剩余时间:{format(need_time, '.0f')}s", | ||||
|                 end="", | ||||
|                 flush=True | ||||
|             ) | ||||
| 
 | ||||
|         out.release() | ||||
|         print("\n渲染完成") | ||||
| 
 | ||||
|     def test_camera(self, cap_index=0): | ||||
|         # 选择数据驱动来源 | ||||
|         cap = cv2.VideoCapture(cap_index) | ||||
| 
 | ||||
|         while cap.isOpened(): | ||||
|             ret, origin_image = cap.read()  # 读取一帧数据 | ||||
|             if not ret: | ||||
|                 break | ||||
| 
 | ||||
|             # 记录开始时间 | ||||
|             begin_time = datetime.now() | ||||
| 
 | ||||
|             # 预测 | ||||
|             frame_image = cv2.cvtColor(origin_image, cv2.COLOR_BGR2RGB) | ||||
|             frame_image, _ = self.detect_image(frame_image) | ||||
|             frame_image = cv2.cvtColor(frame_image, cv2.COLOR_RGB2BGR) | ||||
| 
 | ||||
|             # 记录结束时间 | ||||
|             end_time = begin_time.now() | ||||
| 
 | ||||
|             # 输出FPS | ||||
|             fps = 1000 / ((end_time - begin_time).microseconds / 1000) | ||||
|             print(f"\rFPS:{fps}", flush=True, end="") | ||||
| 
 | ||||
|             # 显示图像 | ||||
|             control = cv2.waitKey(1)  # 1ms后切换下一帧图像 | ||||
|             if control & 0xFF == ord('q'): | ||||
|                 # 释放摄像头并销毁所有窗口 | ||||
|                 cap.release() | ||||
|                 cv2.destroyAllWindows() | ||||
|                 break | ||||
|             cv2.imshow("predict", frame_image) | ||||
| 
 | ||||
|     @staticmethod | ||||
|     # 读取图片 | ||||
|     def load_image(image_path, image_format="cv"): | ||||
|         """ | ||||
|         :param image_path: 图片路径 | ||||
|         :param image_format: 需要导出的图片格式(支持pillow/cv) | ||||
|         :return: 返回处理好的图片数据 | ||||
|         """ | ||||
|         # 读取 | ||||
|         image_data = Image.open(image_path) | ||||
| 
 | ||||
|         # 转换格式 | ||||
|         if image_data.mode != "RGB": | ||||
|             image_data = image_data.convert("RGB") | ||||
| 
 | ||||
|         # 根据图片格式参数改变图片格式 | ||||
|         if image_format == "cv": | ||||
|             image_data = np.array(image_data, dtype="uint8") | ||||
| 
 | ||||
|         return image_data | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user