import torch
import torch.nn as nn
from torchvision.ops import nms
import numpy as np


class DecodeBox():
    def __init__(self, anchors, num_classes, input_shape, anchors_mask=[[6, 7, 8], [3, 4, 5], [0, 1, 2]]):
        super(DecodeBox, self).__init__()
        self.anchors = anchors
        self.num_classes = num_classes
        self.bbox_attrs = 5 + num_classes
        self.input_shape = input_shape
        # -----------------------------------------------------------#
        #   The 13x13 feature layer uses the anchors [116,90], [156,198], [373,326]
        #   The 26x26 feature layer uses the anchors [30,61], [62,45], [59,119]
        #   The 52x52 feature layer uses the anchors [10,13], [16,30], [33,23]
        # -----------------------------------------------------------#
        self.anchors_mask = anchors_mask

    def decode_box(self, inputs):
        outputs = []
        for i, input in enumerate(inputs):
            # -----------------------------------------------#
            #   There are three inputs in total; their shapes are
            #   batch_size, 255, 13, 13
            #   batch_size, 255, 26, 26
            #   batch_size, 255, 52, 52
            # -----------------------------------------------#
            batch_size = input.size(0)
            input_height = input.size(2)
            input_width = input.size(3)

            # -----------------------------------------------#
            #   With a 416x416 input,
            #   stride_h = stride_w = 32, 16, 8
            # -----------------------------------------------#
            stride_h = self.input_shape[0] / input_height
            stride_w = self.input_shape[1] / input_width
            # -------------------------------------------------#
            #   The scaled_anchors obtained here are measured in feature-map units
            # -------------------------------------------------#
            scaled_anchors = [(anchor_width / stride_w, anchor_height / stride_h) for anchor_width, anchor_height in
                              self.anchors[self.anchors_mask[i]]]

            # -----------------------------------------------#
            #   After reshaping, the three inputs have the shapes
            #   batch_size, 3, 13, 13, 85
            #   batch_size, 3, 26, 26, 85
            #   batch_size, 3, 52, 52, 85
            # -----------------------------------------------#
            prediction = input.view(batch_size, len(self.anchors_mask[i]),
                                    self.bbox_attrs, input_height, input_width).permute(0, 1, 3, 4, 2).contiguous()
            # Reshaped to e.g. 1, 3, 13, 13, 25
            # -----------------------------------------------#
            #   Adjustment parameters for the centres of the prior boxes
            # -----------------------------------------------#
            x = torch.sigmoid(prediction[..., 0])
            y = torch.sigmoid(prediction[..., 1])
            # -----------------------------------------------#
            #   Adjustment parameters for the widths and heights of the prior boxes
            # -----------------------------------------------#
            w = prediction[..., 2]
            h = prediction[..., 3]
            # -----------------------------------------------#
            #   Objectness confidence: is there an object in the box?
            # -----------------------------------------------#
            conf = torch.sigmoid(prediction[..., 4])
            # -----------------------------------------------#
            #   Class confidences
            # -----------------------------------------------#
            pred_cls = torch.sigmoid(prediction[..., 5:])

            FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
            LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor

            # ----------------------------------------------------------#
            #   Generate the grid; the prior box centres start from the
            #   top-left corner of each grid cell
            #   batch_size, 3, 13, 13
            # ----------------------------------------------------------#
            grid_x = torch.linspace(0, input_width - 1, input_width).repeat(input_height, 1).repeat(
                batch_size * len(self.anchors_mask[i]), 1, 1).view(x.shape).type(FloatTensor)
            grid_y = torch.linspace(0, input_height - 1, input_height).repeat(input_width, 1).t().repeat(
                batch_size * len(self.anchors_mask[i]), 1, 1).view(y.shape).type(FloatTensor)
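            # For example, on a 3x3 feature map each (image, anchor) slice holds
            #   grid_x = [[0, 1, 2],          grid_y = [[0, 0, 0],
            #             [0, 1, 2],                    [1, 1, 1],
            #             [0, 1, 2]]                    [2, 2, 2]]
            # i.e. every cell stores the column / row index of its own top-left corner.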

            # ----------------------------------------------------------#
            #   Generate the prior box widths and heights in the same grid layout
            #   batch_size, 3, 13, 13
            # ----------------------------------------------------------#
            anchor_w = FloatTensor(scaled_anchors).index_select(1, LongTensor([0]))
            anchor_h = FloatTensor(scaled_anchors).index_select(1, LongTensor([1]))
            anchor_w = anchor_w.repeat(batch_size, 1).repeat(1, 1, input_height * input_width).view(w.shape)
            anchor_h = anchor_h.repeat(batch_size, 1).repeat(1, 1, input_height * input_width).view(h.shape)

            # ----------------------------------------------------------#
            #   Adjust the prior boxes using the predictions.
            #   First adjust the box centre: the sigmoid offsets shift it from the
            #   top-left corner of its grid cell towards the bottom-right.
            #   Then adjust the box width and height.
            # ----------------------------------------------------------#
            pred_boxes = FloatTensor(prediction[..., :4].shape)
            pred_boxes[..., 0] = x.data + grid_x
            pred_boxes[..., 1] = y.data + grid_y
            pred_boxes[..., 2] = torch.exp(w.data) * anchor_w
            pred_boxes[..., 3] = torch.exp(h.data) * anchor_h
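            # Worked example (13x13 layer, 416x416 input): stride = 32, so anchor
            # [116, 90] becomes the scaled anchor (3.625, 2.8125). For the cell with
            # grid_x = 6, grid_y = 6 and raw outputs x = y = w = h = 0:
            # sigmoid(0) = 0.5 and exp(0) = 1, so the decoded box is centred at
            # (6.5, 6.5) with size (3.625, 2.8125) in feature-map units, i.e. about
            # (0.5, 0.5) and (0.28, 0.22) after the normalisation by _scale below.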

            # ----------------------------------------------------------#
            #   Normalise the outputs to the 0-1 range (fractions of the feature-map size)
            # ----------------------------------------------------------#
            _scale = torch.Tensor([input_width, input_height, input_width, input_height]).type(FloatTensor)
            output = torch.cat((pred_boxes.view(batch_size, -1, 4) / _scale,
                                conf.view(batch_size, -1, 1), pred_cls.view(batch_size, -1, self.num_classes)), -1)
            # output has shape batch_size, -1, bbox_attrs (e.g. 25)
            outputs.append(output.data)
        return outputs

    def yolo_correct_boxes(self, box_xy, box_wh, input_shape, image_shape, letterbox_image):
        # -----------------------------------------------------------------#
        #   Put the y axis first because it makes multiplying the boxes by the
        #   image height and width more convenient
        # -----------------------------------------------------------------#
        box_yx = box_xy[..., ::-1]
        box_hw = box_wh[..., ::-1]
        input_shape = np.array(input_shape)
        image_shape = np.array(image_shape)

        if letterbox_image:
            # -----------------------------------------------------------------#
            #   The offset computed here is the offset of the valid image area
            #   relative to the top-left corner of the letterboxed input.
            #   new_shape is the image size after the letterbox resize.
            # -----------------------------------------------------------------#
            new_shape = np.round(image_shape * np.min(input_shape / image_shape))
            offset = (input_shape - new_shape) / 2. / input_shape
            scale = input_shape / new_shape

            box_yx = (box_yx - offset) * scale
            box_hw *= scale

        box_mins = box_yx - (box_hw / 2.)
        box_maxes = box_yx + (box_hw / 2.)
        boxes = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]],
                               axis=-1)
        boxes *= np.concatenate([image_shape, image_shape], axis=-1)
        return boxes
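
    # Worked example for the letterbox branch above: with input_shape = (416, 416)
    # and, say, an original image of image_shape = (480, 640) (height, width),
    # min(416/480, 416/640) = 0.65, so new_shape = (312, 416). The grey padding is
    # only on the vertical axis: offset = ((416 - 312) / 2 / 416, 0) = (0.125, 0)
    # and scale = (416/312, 416/416) ≈ (1.33, 1.0). box_yx is shifted by the offset,
    # stretched by the scale, and finally multiplied by image_shape to obtain pixel
    # coordinates in the original image.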

    def non_max_suppression(self, prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5,
                            nms_thres=0.4):
        # ----------------------------------------------------------#
        #   Convert the predictions to top-left / bottom-right corner format.
        #   prediction  [batch_size, num_anchors, 85]
        # ----------------------------------------------------------#
        box_corner = prediction.new(prediction.shape)
        box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
        box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
        box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
        box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
        prediction[:, :, :4] = box_corner[:, :, :4]

        output = [None for _ in range(len(prediction))]
        for i, image_pred in enumerate(prediction):
            # ----------------------------------------------------------#
            #   Take the max over the class predictions (image_pred iterates over
            #   dimension 0 of prediction; image_pred[:, 5:5 + num_classes] are the class scores).
            #   class_conf  [num_anchors, 1]    class confidence
            #   class_pred  [num_anchors, 1]    class index
            # ----------------------------------------------------------#
            class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)

            # ----------------------------------------------------------#
            #   First round of filtering by confidence
            # ----------------------------------------------------------#
            conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze()

            # ----------------------------------------------------------#
            #   Keep only the predictions that pass the confidence threshold
            # ----------------------------------------------------------#
            image_pred = image_pred[conf_mask]
            class_conf = class_conf[conf_mask]
            class_pred = class_pred[conf_mask]
            if not image_pred.size(0):
                continue  # nothing left for this image, move on to the next one
            # -------------------------------------------------------------------------#
            #   detections  [num_anchors, 7]
            #   the 7 columns are: x1, y1, x2, y2, obj_conf, class_conf, class_pred
            # -------------------------------------------------------------------------#
            detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1)

            # ------------------------------------------#
            #   All the classes present in the remaining predictions
            # ------------------------------------------#
            unique_labels = detections[:, -1].cpu().unique()

            if prediction.is_cuda:
                unique_labels = unique_labels.cuda()
                detections = detections.cuda()

            for c in unique_labels:
                # ------------------------------------------#
                #   All predictions of class c that survived the score filtering
                # ------------------------------------------#
                detections_class = detections[detections[:, -1] == c]

                # ------------------------------------------#
                #   Using the official built-in non-maximum suppression is faster!
                # ------------------------------------------#
                keep = nms(
                    detections_class[:, :4],
                    detections_class[:, 4] * detections_class[:, 5],
                    nms_thres
                )
                max_detections = detections_class[keep]
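                # (torchvision.ops.batched_nms(boxes, scores, idxs, iou_threshold) could
                #  perform NMS for all classes in a single call instead of this per-class loop.)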

                # # Sort by objectness * class confidence
                # _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True)
                # detections_class = detections_class[conf_sort_index]
                # # Perform non-maximum suppression manually
                # max_detections = []
                # while detections_class.size(0):
                #     # Take the highest-confidence box of this class, then compare it with the
                #     # remaining boxes and drop those whose overlap exceeds nms_thres
                #     max_detections.append(detections_class[0].unsqueeze(0))
                #     if len(detections_class) == 1:
                #         break
                #     ious = bbox_iou(max_detections[-1], detections_class[1:])
                #     detections_class = detections_class[1:][ious < nms_thres]
                # # Stack the kept boxes
                # max_detections = torch.cat(max_detections).data

                # Add max detections to outputs
                output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections))

            if output[i] is not None:
                output[i] = output[i].cpu().numpy()
                box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4]) / 2, output[i][:, 2:4] - output[i][:, 0:2]
                output[i][:, :4] = self.yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
        return output
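

# ---------------------------------------------------------------------------#
#   Minimal smoke-test sketch (illustrative only): it feeds random tensors
#   shaped like YOLOv3 outputs for an assumed 20-class model at 416x416 input
#   through decode_box and non_max_suppression. The anchors are the standard
#   YOLOv3 values listed in the comments of __init__; replace anchors,
#   num_classes and image_shape with the values used by your own model/image.
# ---------------------------------------------------------------------------#
if __name__ == "__main__":
    anchors = np.array([[10, 13], [16, 30], [33, 23],
                        [30, 61], [62, 45], [59, 119],
                        [116, 90], [156, 198], [373, 326]])
    num_classes = 20                        # assumed VOC-style class count
    input_shape = [416, 416]

    decoder = DecodeBox(anchors, num_classes, input_shape)

    # One random tensor per feature level, channels = 3 * (5 + num_classes) = 75.
    batch_size = 1
    dummy_outputs = [torch.randn(batch_size, 3 * (5 + num_classes), size, size)
                     for size in (13, 26, 52)]

    decoded = decoder.decode_box(dummy_outputs)
    predictions = torch.cat(decoded, 1)     # [batch_size, 10647, 5 + num_classes]

    # image_shape is (height, width) of the original image, assumed 480x640 here.
    results = decoder.non_max_suppression(predictions, num_classes, input_shape,
                                          image_shape=[480, 640], letterbox_image=True,
                                          conf_thres=0.5, nms_thres=0.4)
    print([None if r is None else r.shape for r in results])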