add file
This commit is contained in:
parent
a3b9b6a102
commit
eeaf19c4e5
234
utils/detection_tool.py
Normal file
234
utils/detection_tool.py
Normal file
@ -0,0 +1,234 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from config import Config
|
||||
from ultralytics import YOLO
|
||||
|
||||
|
||||
# 检测工具类
|
||||
class DetectionTool(object):
    """Run a YOLO model on images, video files or a camera and draw the detections.

    The weights file, the label font and the class names are read from the
    configuration object (``weights_path``, ``font_path`` and ``classes_name``).
    """

    def __init__(self, config: "Config | None" = None):
        """
        Store the configuration and load the prediction model.

        :param config: settings object; a default ``Config()`` is created when omitted
        """
        self.config = Config() if config is None else config

        # Build the prediction model from the configured weights file.
        self.model = YOLO(self.config.weights_path)

    def detect_image(self, image_data, image_format="cv", only_return_result=False):
        """
        Run the model on one image and, unless told otherwise, draw the detections.

        :param image_data: image data (must be a numpy array)
        :param image_format: format of the returned image ("pillow" or "cv")
        :param only_return_result: return only the prediction dictionary
        :return: the ``result`` dictionary when ``only_return_result`` is true,
            otherwise a tuple ``(image, result)``. ``result`` has the keys
            "boxes" (xyxy coordinates), "classes" and "scores", all numpy arrays.
            When nothing is detected the image is returned without annotations.
        """
        # Only the first result is used because a single image is passed in.
        pred_result = self.model(image_data, verbose=False, iou=0.35)[0].boxes

        # Convert each output to numpy once instead of once per box.
        boxes = pred_result.xyxy.cpu().numpy()
        classes = pred_result.cls.cpu().numpy()
        scores = pred_result.conf.cpu().numpy()
        result = {
            "boxes": boxes,
            "classes": classes,
            "scores": scores
        }

        if only_return_result:
            return result

        # Nothing detected: hand the input back, converted only if Pillow was requested.
        if len(boxes) == 0:
            if image_format == "pillow":
                image_data = Image.fromarray(image_data)
            return image_data, result

        annotated = self._draw_detections(Image.fromarray(image_data), boxes, classes, scores)

        # The drawing is done on a Pillow image; convert back when "cv" was requested.
        if image_format == "cv":
            annotated = np.asarray(annotated, dtype="uint8")

        return annotated, result

    def _draw_detections(self, image, boxes, classes, scores):
        """Draw a red box and a "<class>||<score>" label for every detection on a Pillow image."""
        # Font size and line width scale with the image; both are kept >= 1 so
        # very small images do not request an invalid zero-sized font.
        font_size = max(int(np.floor(3e-2 * image.size[1] + 0.5)), 1)
        font = ImageFont.truetype(font=self.config.font_path, size=font_size)
        thickness = max((image.size[0] + image.size[1]) // 300, 1)

        image_draw = ImageDraw.Draw(image)

        # tolist() yields plain Python ints, which is what Pillow expects.
        for bndbox, class_index, score in zip(boxes.astype("int32").tolist(), classes, scores):
            x_min, y_min, x_max, y_max = bndbox
            class_name = self.config.classes_name[int(class_index)]
            text_content = f"{class_name}||{format(float(score), '.2f')}"

            # One call with ``width`` replaces the manual loop of inset rectangles,
            # which produced inverted coordinates for boxes smaller than the line width.
            image_draw.rectangle([x_min, y_min, x_max, y_max], outline=(255, 0, 0), width=thickness)

            text_width, text_height = self._measure_text(image_draw, text_content, font)

            # Put the label above the box, but never above the top edge of the image.
            text_y_min = max(y_min - text_height, 0)
            image_draw.rectangle(
                [(x_min, text_y_min), (x_min + text_width, text_y_min + text_height)],
                fill=(255, 0, 0)
            )
            image_draw.text((x_min, text_y_min), text_content, fill=(0, 0, 0), font=font)

        return image

    @staticmethod
    def _measure_text(image_draw, text, font):
        """Return (width, height) of ``text``; works with and without ``ImageDraw.textsize``."""
        if hasattr(image_draw, "textbbox"):
            # textsize() was removed in Pillow 10; the right/bottom edges of the
            # bounding box measured from the origin cover the rendered glyphs.
            _, _, right, bottom = image_draw.textbbox((0, 0), text, font=font)
            return int(np.ceil(right)), int(np.ceil(bottom))
        return image_draw.textsize(text, font)

    @staticmethod
    def _frames_per_second(elapsed_seconds):
        """Convert the time spent on one frame into frames per second."""
        # A zero reading (timer resolution) must not raise ZeroDivisionError.
        return 1 / elapsed_seconds if elapsed_seconds > 0 else float("inf")

    def _annotate_bgr_frame(self, frame_image):
        """Annotate one OpenCV (BGR) frame and return it in BGR order again."""
        # detect_image is fed RGB data here, so convert around the call.
        frame_image = cv2.cvtColor(frame_image, cv2.COLOR_BGR2RGB)
        frame_image, _ = self.detect_image(frame_image)
        return cv2.cvtColor(frame_image, cv2.COLOR_RGB2BGR)

    def test_image(self):
        """Interactively ask for image paths, annotate each image and show it."""
        while True:
            # Ask for the image path.
            image_path = Path(input("请输入图片的路径:"))

            # Load, predict and display.
            image_data = self.load_image(str(image_path))
            image_data, _ = self.detect_image(image_data, "pillow")
            image_data.show("images_save-pic")

            # Continue unless the user answers "n" (case and spaces are ignored).
            is_next = input("是否继续?输入(y/n):")
            if is_next.strip().lower() == "n":
                break

    def test_video(self):
        """Interactively annotate a video file and save the result as an .mp4 file."""
        # Ask for the source video and the output folder.
        video_path = Path(str(input("请输入视频的路径:")))
        video_save = Path(str(input("请输入保存的文件夹路径:")))

        # Without the folder the writer cannot create the output file.
        video_save.mkdir(parents=True, exist_ok=True)

        cap = cv2.VideoCapture(str(video_path))
        out = None
        try:
            # Read the video parameters; the frame rate is kept as a float so a
            # rate such as 29.97 is not truncated in the output.
            frame_speed = cap.get(cv2.CAP_PROP_FPS)
            num_frame = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            # Configure the output video; ``stem`` drops the extension whatever its length.
            fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
            out = cv2.VideoWriter(str(video_save.joinpath(f"{video_path.stem}.mp4")), fourcc, frame_speed,
                                  (width, height), True)

            count = 0

            while cap.isOpened():
                start_time = datetime.now()

                # Read one frame; stop at the end of the video.
                ret, frame_image = cap.read()
                if not ret:
                    break

                # Annotate and write the frame.
                out.write(self._annotate_bgr_frame(frame_image))
                count += 1

                # total_seconds() includes whole seconds, unlike ``.microseconds``.
                elapsed = (datetime.now() - start_time).total_seconds()
                need_time = elapsed * max(num_frame - count, 0)
                fps = self._frames_per_second(elapsed)

                print(
                    f"\r正在渲染: {count}/{num_frame}|FPS:{fps}|大约剩余时间:{format(need_time, '.0f')}s",
                    end="",
                    flush=True
                )
        finally:
            # Release the reader and the writer even when a frame fails.
            cap.release()
            if out is not None:
                out.release()

        print("\n渲染完成")

    def test_camera(self, cap_index=0):
        """
        Show annotated frames from a capture device until "q" is pressed.

        :param cap_index: capture source passed to ``cv2.VideoCapture``
        """
        cap = cv2.VideoCapture(cap_index)
        try:
            while cap.isOpened():
                # Read one frame; stop when the device delivers nothing.
                ret, origin_image = cap.read()
                if not ret:
                    break

                # Time only the prediction step.
                begin_time = datetime.now()
                frame_image = self._annotate_bgr_frame(origin_image)
                elapsed = (datetime.now() - begin_time).total_seconds()

                # Print the FPS.
                fps = self._frames_per_second(elapsed)
                print(f"\rFPS:{fps}", flush=True, end="")

                # Show the frame, then poll the keyboard for 1 ms.
                cv2.imshow("predict", frame_image)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        finally:
            # Release the camera and close the window on every exit path.
            cap.release()
            cv2.destroyAllWindows()

    @staticmethod
    def load_image(image_path, image_format="cv"):
        """
        Read an image file as RGB.

        :param image_path: path of the image
        :param image_format: format of the returned image ("pillow" or "cv")
        :return: a uint8 numpy array for "cv", otherwise a Pillow image
        """
        # convert() returns a loaded copy, so the file can be closed right away.
        with Image.open(image_path) as source_image:
            image_data = source_image.convert("RGB")

        # Change the container type according to the requested format.
        if image_format == "cv":
            image_data = np.array(image_data, dtype="uint8")

        return image_data
|
||||
Loading…
Reference in New Issue
Block a user