"""Object-detection helpers built around an Ultralytics YOLO model.

This is the content of ``utils/detection_tool.py`` as introduced by the
"add file" patch (commit eeaf19c4e5f3966449ef8bcf7f46863ffe1a673a),
restored to conventional formatting.
"""
from datetime import datetime
from pathlib import Path
from typing import Optional

import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO

from config import Config


class DetectionTool:
    """Run a YOLO model on images, videos or a camera and draw the results."""

    # Annotation colours, as RGB tuples.
    _BOX_COLOR = (255, 0, 0)
    _TEXT_COLOR = (0, 0, 0)

    def __init__(self, config: Optional[Config] = None):
        """Store the configuration and build the prediction model.

        :param config: project configuration; a default ``Config()`` is
            created when omitted. ``weights_path``, ``font_path`` and
            ``classes_name`` are the attributes read by this class.
        """
        self.config = Config() if config is None else config

        # Build the prediction model from the configured weights.
        self.model = YOLO(self.config.weights_path)

    def detect_image(self, image_data, image_format="cv", only_return_result=False, iou=0.35):
        """Detect objects in one image and optionally draw them.

        :param image_data: image as a numpy array in RGB channel order
            (the layout produced by :meth:`load_image`)
        :param image_format: type of the returned image: ``"cv"`` gives a
            numpy array, anything else (normally ``"pillow"``) gives a
            ``PIL.Image.Image`` -- the same rule :meth:`load_image` uses
        :param only_return_result: when true, skip drawing and return only
            the prediction dictionary
        :param iou: IoU threshold forwarded to the model (default 0.35)
        :return: ``result`` alone when ``only_return_result`` is true,
            otherwise ``(image, result)``; ``result`` maps ``"boxes"``
            (xyxy), ``"classes"`` and ``"scores"`` to numpy arrays
        """
        prediction = self.model(self._to_model_input(image_data), verbose=False, iou=iou)[0].boxes

        # Move everything to numpy once; the drawing code reuses these arrays.
        result = {
            "boxes": prediction.xyxy.cpu().numpy(),
            "classes": prediction.cls.cpu().numpy(),
            "scores": prediction.conf.cpu().numpy(),
        }

        if only_return_result:
            return result

        if len(result["boxes"]) > 0:
            # Drawing happens on a Pillow copy; the caller's array is untouched.
            annotated = Image.fromarray(image_data)
            self._draw_detections(annotated, result)
            image_data = annotated

        # Convert to the requested output type, whatever we currently hold.
        holds_pillow = isinstance(image_data, Image.Image)
        if image_format == "cv":
            if holds_pillow:
                image_data = np.asarray(image_data, dtype="uint8")
        elif not holds_pillow:
            image_data = Image.fromarray(image_data)

        return image_data, result

    @staticmethod
    def _to_model_input(image_data):
        """Return ``image_data`` in the channel order the model expects.

        Ultralytics documents numpy input as BGR, while this class works in
        RGB, so three-channel arrays are flipped. Anything else is passed
        through unchanged.
        """
        if isinstance(image_data, np.ndarray) and image_data.ndim == 3 and image_data.shape[2] == 3:
            return np.ascontiguousarray(image_data[:, :, ::-1])
        return image_data

    @staticmethod
    def _measure_text(draw, text, font):
        """Return ``(width, height)`` needed to draw ``text`` at an anchor point.

        ``ImageDraw.textsize`` was removed in Pillow 10; ``textbbox`` is the
        replacement. The right/bottom edges of the box anchored at (0, 0) are
        used so the background covers the glyph offset as well.
        """
        if hasattr(draw, "textbbox"):
            _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
            return right, bottom
        return draw.textsize(text, font)

    def _draw_detections(self, image, result):
        """Draw each box in ``result`` and its ``class||score`` label on ``image`` in place."""
        width, height = image.size

        # Font size and line thickness scale with the image; both are at least 1.
        font = ImageFont.truetype(
            font=self.config.font_path,
            size=max(int(3e-2 * height + 0.5), 1),
        )
        thickness = max((width + height) // 300, 1)
        draw = ImageDraw.Draw(image)

        for box, class_id, score in zip(result["boxes"], result["classes"], result["scores"]):
            x_min, y_min, x_max, y_max = (int(value) for value in box)
            label = f"{self.config.classes_name[int(class_id)]}||{float(score):.2f}"

            # A thick outline is built from nested one-pixel rectangles.
            for inset in range(thickness):
                left, top = x_min + inset, y_min + inset
                right, bottom = x_max - inset, y_max - inset
                if right < left or bottom < top:
                    # Box smaller than the outline: recent Pillow rejects
                    # inverted rectangles, so stop here.
                    break
                draw.rectangle([left, top, right, bottom], outline=self._BOX_COLOR)

            text_width, text_height = self._measure_text(draw, label, font)

            # The label sits above the box, or inside its top edge when there
            # is no room above.
            label_top = y_min - text_height
            if label_top < 0:
                label_top = max(y_min, 0)

            draw.rectangle(
                [(x_min, label_top), (x_min + text_width, label_top + text_height)],
                fill=self._BOX_COLOR,
            )
            draw.text((x_min, label_top), label, fill=self._TEXT_COLOR, font=font)

    def test_image(self):
        """Interactively annotate images: prompt for a path, show the result, repeat."""
        while True:
            image_path = Path(input("请输入图片的路径:"))

            try:
                image_data = self.load_image(str(image_path))
            except OSError as exc:
                # Missing or unreadable file: report it and keep the session alive.
                print(f"无法读取图片: {exc}")
            else:
                image_data, _ = self.detect_image(image_data, "pillow")
                image_data.show("images_save-pic")

            is_next = input("是否继续?输入(y/n):")
            if is_next == "n":
                break

    def test_video(self):
        """Interactively annotate a video file and save it as ``<stem>.mp4``.

        Prompts for the input video and the output directory, then writes the
        annotated frames while printing progress, FPS and a rough ETA.
        """
        video_path = Path(input("请输入视频的路径:"))
        video_save = Path(input("请输入保存的文件夹路径:"))

        cap = cv2.VideoCapture(str(video_path))
        if not cap.isOpened():
            cap.release()
            print(f"无法打开视频: {video_path}")
            return

        out = None
        try:
            # Keep the FPS as a float so rates such as 29.97 are not truncated.
            frame_speed = cap.get(cv2.CAP_PROP_FPS)
            num_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            # NOTE(review): if the output directory is the input's directory
            # and the input is already "<stem>.mp4", the writer targets the
            # file being read -- confirm whether that needs a guard.
            video_save.mkdir(parents=True, exist_ok=True)
            fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
            out = cv2.VideoWriter(
                str(video_save.joinpath(f"{video_path.stem}.mp4")),
                fourcc,
                frame_speed,
                (width, height),
                True,
            )

            count = 0
            while True:
                start_time = datetime.now()

                ret, frame_image = cap.read()
                if not ret:
                    break

                # OpenCV frames are BGR; detect_image works in RGB.
                frame_image = cv2.cvtColor(frame_image, cv2.COLOR_BGR2RGB)
                frame_image, _ = self.detect_image(frame_image)
                frame_image = cv2.cvtColor(frame_image, cv2.COLOR_RGB2BGR)

                out.write(frame_image)
                count += 1

                # total_seconds() covers frames slower than one second too.
                elapsed = (datetime.now() - start_time).total_seconds()
                fps = 1.0 / elapsed if elapsed > 0 else float("inf")
                need_time = max(elapsed * (num_frame - count), 0.0)

                print(
                    f"\r正在渲染: {count}/{num_frame}|FPS:{fps}|大约剩余时间:{format(need_time, '.0f')}s",
                    end="",
                    flush=True
                )
        finally:
            cap.release()
            if out is not None:
                out.release()

        print("\n渲染完成")

    def test_camera(self, cap_index=0):
        """Show live annotated frames from a camera until ``q`` is pressed.

        :param cap_index: source passed to ``cv2.VideoCapture`` (default 0)
        """
        cap = cv2.VideoCapture(cap_index)

        try:
            while cap.isOpened():
                ret, origin_image = cap.read()
                if not ret:
                    break

                begin_time = datetime.now()

                # OpenCV frames are BGR; detect_image works in RGB.
                frame_image = cv2.cvtColor(origin_image, cv2.COLOR_BGR2RGB)
                frame_image, _ = self.detect_image(frame_image)
                frame_image = cv2.cvtColor(frame_image, cv2.COLOR_RGB2BGR)

                elapsed = (datetime.now() - begin_time).total_seconds()
                fps = 1.0 / elapsed if elapsed > 0 else float("inf")
                print(f"\rFPS:{fps}", flush=True, end="")

                cv2.imshow("predict", frame_image)
                # Wait 1 ms for a key press; "q" ends the session.
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            # Always free the camera and close the preview window.
            cap.release()
            cv2.destroyAllWindows()

    @staticmethod
    def load_image(image_path, image_format="cv"):
        """Read an image file as RGB.

        :param image_path: path of the image file
        :param image_format: ``"cv"`` returns a uint8 numpy array; any other
            value returns a ``PIL.Image.Image``
        :return: the image in the requested representation
        """
        # convert() returns a fully loaded copy, so the file can be closed
        # as soon as the block exits.
        with Image.open(image_path) as source:
            image_data = source.convert("RGB")

        if image_format == "cv":
            image_data = np.array(image_data, dtype="uint8")

        return image_data