# model-yolo11-weitong-id/utils/detection_tool.py
# 2025-10-30 09:59:51 +08:00
#
# 235 lines
# 7.9 KiB
# Python

import cv2
import numpy as np
from pathlib import Path
from datetime import datetime
from PIL import Image, ImageDraw, ImageFont
from config import Config
from ultralytics import YOLO
# 检测工具类
class DetectionTool(object):
    """Wrap an Ultralytics YOLO model with helpers to annotate images, videos and camera streams.

    Reads ``weights_path``, ``font_path`` and ``classes_name`` from the ``Config`` object.
    """

    def __init__(self, config: Config = None):
        """
        :param config: Configuration object; a default ``Config()`` is created when omitted
        """
        # Basic settings
        self.config = Config() if config is None else config
        # Build the prediction model from the configured weights
        self.model = YOLO(self.config.weights_path)

    def detect_image(self, image_data, image_format="cv", only_return_result=False, iou=0.35):
        """Run the detector on one image and optionally draw the predictions on it.

        :param image_data: Image data (must be a numpy array; the callers in this class
            pass RGB channel order)
        :param image_format: Output image format, "cv" (numpy array) or "pillow" (PIL Image)
        :param only_return_result: Return only the prediction dict, without drawing
        :param iou: IoU threshold forwarded to the model call (0.35 was previously hard-coded)
        :return: ``result`` when ``only_return_result`` is true, otherwise
            ``(image, result)``. ``result`` holds numpy arrays under "boxes" (xyxy),
            "classes" and "scores".
        """
        # [0] selects the prediction for the single input image
        pred_result = self.model(image_data, verbose=False, iou=iou)[0].boxes
        # Move everything to CPU numpy once; the drawing loop reuses these arrays
        result = {
            "boxes": pred_result.xyxy.cpu().numpy(),
            "classes": pred_result.cls.cpu().numpy(),
            "scores": pred_result.conf.cpu().numpy(),
        }
        if only_return_result:
            return result
        # Only convert to Pillow and draw when there is at least one box
        if len(result["boxes"]) > 0:
            image_data = self._draw_predictions(Image.fromarray(image_data), result)
        # Convert to the requested output format (no-op if already in that format)
        if image_format == "cv" and isinstance(image_data, Image.Image):
            image_data = np.asarray(image_data, dtype="uint8")
        elif image_format == "pillow" and not isinstance(image_data, Image.Image):
            image_data = Image.fromarray(image_data)
        return image_data, result

    def _draw_predictions(self, image, result):
        """Draw each predicted box and its "class||score" label onto ``image`` in place.

        :param image: PIL Image to draw on
        :param result: Prediction dict as built by ``detect_image``
        :return: The same (now annotated) PIL Image
        """
        width, height = image.size
        # Font size is ~3% of the image height; Pillow rejects sizes <= 0, which the
        # rounding produces for very small images, so clamp to at least 1
        font = ImageFont.truetype(
            font=self.config.font_path,
            size=max(int(3e-2 * height + 0.5), 1),
        )
        # Outline width scales with the image dimensions
        thickness = max((width + height) // 300, 1)
        draw = ImageDraw.Draw(image)
        for bndbox, class_id, score in zip(result["boxes"], result["classes"], result["scores"]):
            class_name = self.config.classes_name[int(class_id)]
            x_min, y_min, x_max, y_max = (int(v) for v in bndbox)
            text_content = f"{class_name}||{float(score):.2f}"
            # width= draws the outline inward like the old per-pixel loop, but does not
            # produce x1 < x0 (a ValueError in recent Pillow) on very thin boxes
            draw.rectangle([x_min, y_min, x_max, y_max], outline=(255, 0, 0), width=thickness)
            text_w, text_h = self._text_size(draw, text_content, font)
            # Put the label above the box; if that would leave the image, put it
            # just inside the box's top edge instead
            text_x_min = x_min
            text_y_min = y_min - text_h if y_min - text_h >= 0 else y_min + 1
            draw.rectangle(
                [(text_x_min, text_y_min), (text_x_min + text_w, text_y_min + text_h)],
                fill=(255, 0, 0),
            )
            draw.text((text_x_min, text_y_min), text_content, fill=(0, 0, 0), font=font)
        return image

    @staticmethod
    def _text_size(draw, text, font):
        """Return the (width, height) covered by ``text`` when drawn at the origin with ``font``.

        ``ImageDraw.textsize`` was removed in Pillow 10, so ``textbbox`` is preferred
        and ``textsize`` is only used on old Pillow versions that lack it.
        """
        if hasattr(draw, "textbbox"):
            _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
            return right, bottom
        return draw.textsize(text, font)

    def test_image(self):
        """Interactively read image paths, run detection and show the annotated image.

        Repeats until the user answers "n" to the continue prompt.
        """
        while True:
            image_path = Path(input("请输入图片的路径:").strip())
            try:
                image_data = self.load_image(str(image_path))
            except OSError as err:
                # Missing or unreadable file: report it and ask again instead of
                # aborting the whole interactive session
                print(f"无法读取图片: {err}")
                continue
            image_data, _ = self.detect_image(image_data, "pillow")
            image_data.show("images_save-pic")
            is_next = input("是否继续?输入(y/n):")
            if is_next == "n":
                break

    def test_video(self):
        """Annotate every frame of a video file and save the result as ``<stem>.mp4`` in a chosen folder."""
        video_path = Path(input("请输入视频的路径:"))
        video_save = Path(input("请输入保存的文件夹路径:"))
        # cv2.VideoWriter silently writes nothing when the target folder is missing
        video_save.mkdir(parents=True, exist_ok=True)
        cap = cv2.VideoCapture(str(video_path))
        out = None
        try:
            # Keep the FPS as a float so rates like 29.97 do not change playback speed
            frame_speed = cap.get(cv2.CAP_PROP_FPS)
            num_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
            out = cv2.VideoWriter(
                str(video_save.joinpath(f"{video_path.stem}.mp4")), fourcc, frame_speed,
                (width, height), True
            )
            count = 0
            while cap.isOpened():
                start_time = datetime.now()
                ret, frame_image = cap.read()
                if not ret:
                    break
                # The model path works on RGB; OpenCV reads and writes BGR
                frame_image = cv2.cvtColor(frame_image, cv2.COLOR_BGR2RGB)
                frame_image, _ = self.detect_image(frame_image)
                frame_image = cv2.cvtColor(frame_image, cv2.COLOR_RGB2BGR)
                out.write(frame_image)
                count += 1
                # total_seconds() includes whole seconds; .microseconds alone does not
                elapsed = (datetime.now() - start_time).total_seconds()
                need_time = elapsed * (num_frame - count)
                fps = 1 / elapsed if elapsed > 0 else float("inf")
                print(
                    f"\r正在渲染: {count}/{num_frame}|FPS:{fps}|大约剩余时间:{format(need_time, '.0f')}s",
                    end="",
                    flush=True
                )
        finally:
            cap.release()
            if out is not None:
                out.release()
        print("\n渲染完成")

    def test_camera(self, cap_index=0):
        """Show live detections from a capture device in an OpenCV window; press "q" to quit.

        :param cap_index: Source passed to ``cv2.VideoCapture`` (a device index by default)
        """
        cap = cv2.VideoCapture(cap_index)
        try:
            while cap.isOpened():
                ret, origin_image = cap.read()
                if not ret:
                    break
                begin_time = datetime.now()
                frame_image = cv2.cvtColor(origin_image, cv2.COLOR_BGR2RGB)
                frame_image, _ = self.detect_image(frame_image)
                frame_image = cv2.cvtColor(frame_image, cv2.COLOR_RGB2BGR)
                elapsed = (datetime.now() - begin_time).total_seconds()
                fps = 1 / elapsed if elapsed > 0 else float("inf")
                print(f"\rFPS:{fps}", flush=True, end="")
                # Show the current frame before polling the keyboard (1 ms), so the
                # window is not always one frame behind
                cv2.imshow("predict", frame_image)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            # Release the device and close windows however the loop ends
            cap.release()
            cv2.destroyAllWindows()

    @staticmethod
    def load_image(image_path, image_format="cv"):
        """Read an image file and return it in RGB mode.

        :param image_path: Image path
        :param image_format: "cv" for a uint8 numpy array; any other value returns a PIL Image
        :return: The RGB image data
        """
        # The with-block closes the file handle. convert() always returns a new,
        # fully loaded image, so the result remains valid after the file is closed.
        with Image.open(image_path) as image_file:
            image_data = image_file.convert("RGB")
        if image_format == "cv":
            image_data = np.array(image_data, dtype="uint8")
        return image_data