This commit is contained in:
weitong 2025-10-30 09:59:51 +08:00
parent a3b9b6a102
commit eeaf19c4e5

234
utils/detection_tool.py Normal file
View File

@ -0,0 +1,234 @@
import cv2
import numpy as np
from pathlib import Path
from datetime import datetime
from PIL import Image, ImageDraw, ImageFont
from config import Config
from ultralytics import YOLO
# 检测工具类
class DetectionTool:
    """Run a YOLO detection model on images, video files or a camera stream.

    The model is loaded from ``config.weights_path``. ``config.font_path`` and
    ``config.classes_name`` are used when drawing labelled boxes.
    """

    def __init__(self, config: "Config | None" = None):
        """Store the configuration and load the detection model.

        :param config: project configuration; ``Config()`` is created when
            it is omitted.
        """
        self.config = Config() if config is None else config
        # Build the prediction model from the configured weights file.
        self.model = YOLO(self.config.weights_path)

    def detect_image(self, image_data, image_format="cv", only_return_result=False):
        """Detect objects in an image and optionally draw them onto it.

        NOTE(review): the callers in this class convert frames to RGB before
        calling this method, while Ultralytics documents numpy input as BGR.
        Confirm the intended channel order for inference.

        :param image_data: image as a numpy array
        :param image_format: format of the returned image ("pillow" or "cv")
        :param only_return_result: return only the prediction dict, no image
        :return: the prediction dict (keys "boxes", "classes", "scores", each
            a numpy array) when ``only_return_result`` is true, otherwise the
            tuple ``(image, prediction dict)``
        """
        pred_result = self.model(image_data, verbose=False, iou=0.35)[0].boxes
        result = {
            "boxes": pred_result.xyxy.cpu().numpy(),
            "classes": pred_result.cls.cpu().numpy(),
            "scores": pred_result.conf.cpu().numpy(),
        }
        if only_return_result:
            return result

        if len(result["boxes"]) > 0:
            # Drawing is done with Pillow so a TrueType font can be used.
            image_data = self._draw_detections(Image.fromarray(image_data), result)
            if image_format == "cv":
                image_data = np.asarray(image_data, dtype="uint8")
        elif image_format == "pillow":
            # Nothing to draw; only honour the requested output format.
            image_data = Image.fromarray(image_data)
        return image_data, result

    def _draw_detections(self, image, result):
        """Draw a red box and a "class||score" label for every detection.

        :param image: Pillow image, modified in place
        :param result: prediction dict as built by ``detect_image``
        :return: the same Pillow image
        """
        # Font size and line thickness scale with the image dimensions.
        # Both are kept >= 1 so tiny images do not produce an invalid font.
        font_size = max(int(np.floor(3e-2 * image.size[1] + 0.5)), 1)
        font = ImageFont.truetype(font=self.config.font_path, size=font_size)
        thickness = max((image.size[0] + image.size[1]) // 300, 1)
        image_draw = ImageDraw.Draw(image)

        detections = zip(result["boxes"], result["classes"], result["scores"])
        for bndbox, class_id, score in detections:
            class_name = self.config.classes_name[int(class_id)]
            x_min, y_min, x_max, y_max = bndbox.astype("int32").tolist()
            text_content = f"{class_name}||{float(score):.2f}"

            # A thick outline is built from nested 1-px rectangles.
            insets = self._outline_insets(x_min, y_min, x_max, y_max, thickness)
            for inset in range(insets):
                image_draw.rectangle(
                    [x_min + inset, y_min + inset, x_max - inset, y_max - inset],
                    outline=(255, 0, 0),
                )

            text_w, text_h = self._measure_text(image_draw, text_content, font)
            x0, y0, x1, y1 = self._label_box(x_min, y_min, text_w, text_h)
            # Filled background first, then the text on top of it.
            image_draw.rectangle([(x0, y0), (x1, y1)], fill=(255, 0, 0))
            image_draw.text((x0, y0), text_content, fill=(0, 0, 0), font=font)
        return image

    @staticmethod
    def _outline_insets(x_min, y_min, x_max, y_max, thickness):
        """Return how many nested 1-px outlines fit inside the box.

        Limiting the count keeps ``min <= max`` on both axes for every inset,
        because an inverted rectangle is rejected by recent Pillow versions.
        """
        fit = min(x_max - x_min, y_max - y_min) // 2 + 1
        return max(min(thickness, fit), 0)

    @staticmethod
    def _measure_text(image_draw, text, font):
        """Return ``(width, height)`` needed to draw ``text`` at an origin.

        ``ImageDraw.textsize`` was removed in Pillow 10, so ``textbbox`` is
        preferred and ``textsize`` is only a fallback for old Pillow versions.
        """
        if hasattr(image_draw, "textbbox"):
            # The right/bottom edges measured from (0, 0) cover the glyph
            # offsets, so a background of this size contains the drawn text.
            _, _, right, bottom = image_draw.textbbox((0, 0), text, font=font)
            return right, bottom
        return image_draw.textsize(text, font)

    @staticmethod
    def _label_box(x_min, y_min, text_w, text_h):
        """Return ``(x0, y0, x1, y1)`` of the label background.

        The label sits directly above the box. When that would leave the
        image at the top, it is placed just inside the box instead.
        """
        top = y_min - text_h
        if top < 0:
            top = y_min
        return x_min, top, x_min + text_w, top + text_h

    @staticmethod
    def _frame_rate(elapsed_seconds):
        """Return frames per second for one frame that took ``elapsed_seconds``."""
        if elapsed_seconds <= 0:
            # Timer resolution can yield a zero interval; avoid dividing by it.
            return float("inf")
        return 1.0 / elapsed_seconds

    def test_image(self):
        """Interactively annotate images until the user answers "n"."""
        while True:
            image_path = Path(input("请输入图片的路径:"))
            image_data = self.load_image(str(image_path))
            image_data, _ = self.detect_image(image_data, "pillow")
            image_data.show("images_save-pic")
            if input("是否继续?输入(y/n):") == "n":
                break

    def test_video(self):
        """Annotate every frame of a video and save the result as an .mp4.

        Prompts for the source video path and the output folder, then writes
        ``<source stem>.mp4`` into that folder while printing progress.
        """
        video_path = Path(str(input("请输入视频的路径:")))
        video_save = Path(str(input("请输入保存的文件夹路径:")))
        cap = cv2.VideoCapture(str(video_path))
        out = None
        try:
            frame_speed = int(cap.get(cv2.CAP_PROP_FPS))
            num_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
            # ``stem`` drops the extension whatever its length.
            out = cv2.VideoWriter(
                str(video_save / f"{video_path.stem}.mp4"),
                fourcc,
                frame_speed,
                (width, height),
                True,
            )
            count = 0
            while cap.isOpened():
                start_time = datetime.now()
                ret, frame_image = cap.read()
                if not ret:
                    break
                # Drawing happens in RGB (Pillow); OpenCV reads/writes BGR.
                frame_image = cv2.cvtColor(frame_image, cv2.COLOR_BGR2RGB)
                frame_image, _ = self.detect_image(frame_image)
                frame_image = cv2.cvtColor(frame_image, cv2.COLOR_RGB2BGR)
                out.write(frame_image)
                count += 1
                # total_seconds() also counts whole seconds, unlike the
                # ``microseconds`` attribute.
                elapsed = (datetime.now() - start_time).total_seconds()
                # The reported frame count may be inaccurate; never go negative.
                need_time = elapsed * max(num_frame - count, 0)
                fps = self._frame_rate(elapsed)
                print(
                    f"\r正在渲染: {count}/{num_frame}|FPS:{fps}|大约剩余时间:{need_time:.0f}s",
                    end="",
                    flush=True
                )
        finally:
            # Release the reader and the writer even if a frame fails.
            cap.release()
            if out is not None:
                out.release()
        print("\n渲染完成")

    def test_camera(self, cap_index=0):
        """Show live detections from a camera until "q" is pressed.

        :param cap_index: device index passed to ``cv2.VideoCapture``
        """
        cap = cv2.VideoCapture(cap_index)
        try:
            while cap.isOpened():
                ret, origin_image = cap.read()
                if not ret:
                    break
                begin_time = datetime.now()
                frame_image = cv2.cvtColor(origin_image, cv2.COLOR_BGR2RGB)
                frame_image, _ = self.detect_image(frame_image)
                frame_image = cv2.cvtColor(frame_image, cv2.COLOR_RGB2BGR)
                elapsed = (datetime.now() - begin_time).total_seconds()
                fps = self._frame_rate(elapsed)
                print(f"\rFPS:{fps}", flush=True, end="")
                cv2.imshow("predict", frame_image)
                # waitKey also lets the window refresh; "q" ends the session.
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            # Release the camera and close the window on every exit path.
            cap.release()
            cv2.destroyAllWindows()

    @staticmethod
    def load_image(image_path, image_format="cv"):
        """Read an image file as RGB.

        :param image_path: path of the image file
        :param image_format: "cv" returns a uint8 numpy array; any other
            value returns the Pillow image
        :return: the loaded image in the requested format
        """
        with Image.open(image_path) as opened:
            # convert() returns a new, fully loaded image, so the pixel data
            # stays valid after the file handle is closed.
            image_data = opened.convert("RGB")
        if image_format == "cv":
            image_data = np.array(image_data, dtype="uint8")
        return image_data