from datetime import datetime
from pathlib import Path

import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO

from config import Config


class DetectionTool:
    """Run a YOLO model on images, video files or a camera and draw the detections."""

    def __init__(self, config: "Config | None" = None):
        """Store the configuration and load the model weights.

        :param config: project configuration; a default ``Config()`` is
            created when omitted.
        """
        self.config = Config() if config is None else config
        # Build the prediction model from the configured weights file.
        self.model = YOLO(self.config.weights_path)
        # Fonts keyed by pixel size, so a video does not re-read the font
        # file from disk for every single frame.
        self._font_cache = {}

    def detect_image(self, image_data, image_format="cv", only_return_result=False):
        """Detect objects in one image and optionally draw them.

        :param image_data: image as a numpy array
        :param image_format: format of the returned image ("pillow" or "cv")
        :param only_return_result: when true, return only the raw predictions
        :return: ``result`` when ``only_return_result`` is true, otherwise
            ``(annotated_image, result)``; ``result`` is a dict with the
            numpy arrays "boxes" (xyxy), "classes" and "scores".
        """
        # NOTE(review): the callers in this class pass RGB arrays; Ultralytics
        # documents numpy sources as BGR - confirm the channel order the
        # weights expect.
        pred_result = self.model(image_data, verbose=False, iou=0.35)[0].boxes
        result = {
            "boxes": pred_result.xyxy.cpu().numpy(),
            "classes": pred_result.cls.cpu().numpy(),
            "scores": pred_result.conf.cpu().numpy(),
        }
        if only_return_result:
            return result

        if len(result["boxes"]) > 0:
            # Drawing is done with Pillow; the annotated image stays a Pillow
            # object unless the caller asked for the "cv" (numpy) format.
            image_data = self._draw_detections(Image.fromarray(image_data), result)
            if image_format == "cv":
                image_data = np.asarray(image_data, dtype="uint8")
        elif image_format == "pillow":
            # Nothing to draw: only convert the untouched input if requested.
            image_data = Image.fromarray(image_data)
        return image_data, result

    def _draw_detections(self, image, result):
        """Draw every box and its "class||score" label onto the Pillow ``image``."""
        width, height = image.size
        # Font size and line thickness scale with the image; both are at least 1.
        font = self._get_font(max(int(np.floor(3e-2 * height + 0.5)), 1))
        thickness = max((width + height) // 300, 1)
        image_draw = ImageDraw.Draw(image)

        for bndbox, class_id, score in zip(
            result["boxes"], result["classes"], result["scores"]
        ):
            class_name = self.config.classes_name[int(class_id)]
            x_min, y_min, x_max, y_max = (int(v) for v in bndbox.astype("int32"))
            text_content = f"{class_name}||{format(float(score), '.2f')}"

            # A thick outline is built from nested 1px rectangles.
            for line_num in range(thickness):
                left, top = x_min + line_num, y_min + line_num
                right, bottom = x_max - line_num, y_max - line_num
                if right < left or bottom < top:
                    # The box is smaller than the outline thickness; recent
                    # Pillow versions reject inverted rectangles.
                    break
                image_draw.rectangle([left, top, right, bottom], outline=(255, 0, 0))

            text_width, text_height = self._measure_text(image_draw, text_content, font)
            # The label sits just above the box, clamped so that it stays
            # visible when the box touches the top edge of the image.
            text_x_min = x_min
            text_y_min = max(y_min - text_height, 0)
            text_x_max = text_x_min + text_width
            text_y_max = text_y_min + text_height

            image_draw.rectangle(
                [(text_x_min, text_y_min), (text_x_max, text_y_max)],
                fill=(255, 0, 0),
            )
            image_draw.text(
                (text_x_min, text_y_min), text_content, fill=(0, 0, 0), font=font
            )
        return image

    def _get_font(self, size):
        """Return the configured TrueType font at ``size`` pixels, loading it once."""
        font = self._font_cache.get(size)
        if font is None:
            font = ImageFont.truetype(font=self.config.font_path, size=size)
            self._font_cache[size] = font
        return font

    @staticmethod
    def _measure_text(image_draw, text, font):
        """Return ``(width, height)`` of the area ``text`` covers when drawn at (0, 0)."""
        if hasattr(image_draw, "textbbox"):
            # ``textsize`` was removed in Pillow 10; ``textbbox`` replaces it.
            # The right/bottom edges include the glyph offset, so a background
            # of this size fully covers the rendered text.
            _, _, right, bottom = image_draw.textbbox((0, 0), text, font=font)
            return int(np.ceil(right)), int(np.ceil(bottom))
        # Older Pillow releases without ``textbbox``.
        return image_draw.textsize(text, font)

    def test_image(self):
        """Interactively annotate images until the user answers "n"."""
        while True:
            # Ask for the image path and load it as an RGB numpy array.
            image_path = Path(input("请输入图片的路径:"))
            image_data = self.load_image(str(image_path))
            # Predict and display the annotated Pillow image.
            image_data, _ = self.detect_image(image_data, "pillow")
            image_data.show("images_save-pic")
            # Ask whether to continue.
            is_next = input("是否继续?输入(y/n):")
            if is_next == "n":
                break

    def test_video(self):
        """Annotate every frame of a video file and save the result as .mp4."""
        # Ask for the input video and the output folder.
        video_path = Path(str(input("请输入视频的路径:")))
        video_save = Path(str(input("请输入保存的文件夹路径:")))
        # VideoWriter does not create missing folders.
        video_save.mkdir(parents=True, exist_ok=True)

        cap = cv2.VideoCapture(str(video_path))
        out = None
        try:
            # Keep the frame rate as a float so rates such as 29.97 are not
            # truncated in the exported video.
            frame_speed = cap.get(cv2.CAP_PROP_FPS)
            num_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
            # ``stem`` drops the suffix whatever its length (".mp4", ".webm", ...).
            out = cv2.VideoWriter(
                str(video_save.joinpath(f"{video_path.stem}.mp4")),
                fourcc, frame_speed, (width, height), True
            )

            count = 0
            while cap.isOpened():
                start_time = datetime.now()
                ret, frame_image = cap.read()
                if not ret:
                    break

                # OpenCV delivers BGR frames; the drawing code works on RGB.
                frame_image = cv2.cvtColor(frame_image, cv2.COLOR_BGR2RGB)
                frame_image, _ = self.detect_image(frame_image)
                frame_image = cv2.cvtColor(frame_image, cv2.COLOR_RGB2BGR)
                out.write(frame_image)
                count += 1

                # ``total_seconds`` covers frames slower than one second,
                # which ``microseconds`` alone silently drops.
                elapsed = (datetime.now() - start_time).total_seconds()
                need_time = elapsed * max(num_frame - count, 0)
                fps = 1 / elapsed if elapsed > 0 else float("inf")
                print(
                    f"\r正在渲染: {count}/{num_frame}|FPS:{fps}|大约剩余时间:{format(need_time, '.0f')}s",
                    end="", flush=True
                )
        finally:
            # Always release the reader and finalise the output file.
            cap.release()
            if out is not None:
                out.release()
        print("\n渲染完成")

    def test_camera(self, cap_index=0):
        """Show live annotated camera frames until "q" is pressed.

        :param cap_index: capture source passed to ``cv2.VideoCapture``
        """
        cap = cv2.VideoCapture(cap_index)
        try:
            while cap.isOpened():
                ret, origin_image = cap.read()
                if not ret:
                    break

                begin_time = datetime.now()
                # OpenCV delivers BGR frames; the drawing code works on RGB.
                frame_image = cv2.cvtColor(origin_image, cv2.COLOR_BGR2RGB)
                frame_image, _ = self.detect_image(frame_image)
                frame_image = cv2.cvtColor(frame_image, cv2.COLOR_RGB2BGR)
                elapsed = (datetime.now() - begin_time).total_seconds()

                fps = 1 / elapsed if elapsed > 0 else float("inf")
                print(f"\rFPS:{fps}", flush=True, end="")

                # Show the frame first, then let the GUI process events for
                # 1 ms and poll the keyboard.
                cv2.imshow("predict", frame_image)
                control = cv2.waitKey(1)
                if control & 0xFF == ord('q'):
                    break
        finally:
            # Release the camera and close the window on every exit path.
            cap.release()
            cv2.destroyAllWindows()

    @staticmethod
    def load_image(image_path, image_format="cv"):
        """Read an image file as RGB.

        :param image_path: path of the image file
        :param image_format: format of the returned image ("pillow" or "cv")
        :return: numpy uint8 array for "cv", otherwise a Pillow image
        """
        # The context manager closes the file handle; ``convert`` returns a
        # fully loaded image (a copy when the mode is already RGB), so the
        # pixel data remains usable afterwards.
        with Image.open(image_path) as opened:
            image_data = opened.convert("RGB")
        if image_format == "cv":
            image_data = np.array(image_data, dtype="uint8")
        return image_data