fujie_code/utils/dataloader.py

171 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
from utils.utils import cvtColor, preprocess_input
class YoloDataset(Dataset):
    """YOLO detection dataset built from whitespace-separated annotation lines.

    Each annotation line looks like ``<image_path> x1,y1,x2,y2,... x1,y1,x2,y2,...``:
    the first token is the image path and every following token is one
    comma-separated integer box whose first four fields are the top-left and
    bottom-right corners in original-image pixels. Any further fields (presumably
    the class id — confirm against the annotation generator) are carried through
    unchanged.

    Args:
        annotation_lines: list of annotation strings; all are kept in memory.
        input_shape: network input size as ``(height, width)``, e.g. ``[416, 416]``.
        num_classes: number of object classes (stored only; not used in this class).
        train: when True, ``__getitem__`` applies random data augmentation;
            otherwise images are only letterboxed.
    """

    def __init__(self, annotation_lines, input_shape, num_classes, train):
        super(YoloDataset, self).__init__()
        self.annotation_lines = annotation_lines
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.length = len(self.annotation_lines)
        self.train = train

    def __len__(self):
        """Return the number of annotation lines."""
        return self.length

    def __getitem__(self, index):
        """Load, optionally augment, and normalize one sample.

        Returns:
            image: the ``preprocess_input`` output transposed from (H, W, C) to
                (C, H, W) (the original notes say preprocessing scales pixels
                into 0~1; assumes it keeps the HWC shape).
            box: float32 array; for each row, columns 0-3 become the box centre
                ``(cx, cy)`` and size ``(w, h)`` normalized by the input width and
                height, and any remaining columns are passed through. May be empty.
        """
        index = index % self.length
        # Random augmentation only for the training split; validation just letterboxes.
        image, box = self.get_random_data(self.annotation_lines[index], self.input_shape[0:2],
                                          random=self.train)
        # HWC -> CHW layout expected by PyTorch convolutions.
        image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1))
        box = np.array(box, dtype=np.float32)
        if len(box) != 0:
            # Normalize corner coordinates: x by input width, y by input height.
            box[:, [0, 2]] = box[:, [0, 2]] / self.input_shape[1]
            box[:, [1, 3]] = box[:, [1, 3]] / self.input_shape[0]
            # (x1, y1, x2, y2) -> (x1, y1, w, h)
            box[:, 2:4] = box[:, 2:4] - box[:, 0:2]
            # (x1, y1, w, h) -> (cx, cy, w, h)
            box[:, 0:2] = box[:, 0:2] + box[:, 2:4] / 2
        return image, box

    def rand(self, a=0, b=1):
        """Return a uniform random float in [a, b) from NumPy's global RNG."""
        return np.random.rand() * (b - a) + a

    def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True):
        """Load one annotated image and fit it to ``input_shape``.

        With ``random=False`` the image is letterboxed: scaled to fit while keeping
        its aspect ratio and centred on a gray (128, 128, 128) canvas. With
        ``random=True`` it is augmented instead: random aspect-ratio jitter and
        scale, random placement on the gray canvas (parts may be cropped off), a
        50% chance of horizontal flip, and random HSV jitter.

        Args:
            annotation_line: ``"<image_path> x1,y1,x2,y2,... ..."``.
            input_shape: target ``(h, w)``.
            jitter: maximum relative aspect-ratio distortion.
            hue, sat, val: maximum relative gain applied to the H, S and V channels.
            random: enable augmentation.

        Returns:
            image_data: (h, w, 3) RGB array; float32 when ``random`` is False,
                uint8 when True.
            box: boxes mapped into the new image, still as (x1, y1, x2, y2, ...),
                clipped to the canvas, with boxes of width or height <= 1 removed.
        """
        line = annotation_line.split()
        # line[0] is the image path; the section's original note says the project
        # helper cvtColor converts the image to RGB.
        image = Image.open(line[0])
        image = cvtColor(image)
        iw, ih = image.size
        h, w = input_shape
        # One integer row per box token (an empty 1-D array if there are none).
        # NOTE: because this array is integer, the scaled coordinates written back
        # into it in _transform_boxes are truncated to whole pixels.
        box = np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]])

        if not random:
            # Validation path: letterbox into (w, h) keeping the aspect ratio.
            scale = min(w / iw, h / ih)
            nw = int(iw * scale)
            nh = int(ih * scale)
            dx = (w - nw) // 2
            dy = (h - nh) // 2
            image = image.resize((nw, nh), Image.BICUBIC)
            new_image = Image.new('RGB', (w, h), (128, 128, 128))
            new_image.paste(image, (dx, dy))
            image_data = np.array(new_image, np.float32)
            box = self._transform_boxes(box, iw, ih, nw, nh, dx, dy, w, h)
            return image_data, box

        # Training path. All draws use NumPy's global RNG, in this fixed order:
        # aspect numerator, aspect denominator, scale, dx, dy, flip, HSV gains,
        # box shuffle.
        new_ar = iw / ih * self.rand(1 - jitter, 1 + jitter) / self.rand(1 - jitter, 1 + jitter)
        scale = self.rand(.25, 2)
        if new_ar < 1:
            # Taller than wide: derive the height from the scale first.
            nh = int(scale * h)
            nw = int(nh * new_ar)
        else:
            nw = int(scale * w)
            nh = int(nw / new_ar)
        image = image.resize((nw, nh), Image.BICUBIC)

        # Random placement on a gray (w, h) canvas; offsets can be negative (the
        # image is cropped) when the resized image is larger than the canvas.
        dx = int(self.rand(0, w - nw))
        dy = int(self.rand(0, h - nh))
        new_image = Image.new('RGB', (w, h), (128, 128, 128))
        new_image.paste(image, (dx, dy))
        image = new_image

        flip = self.rand() < .5
        if flip:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        image_data = self._hsv_jitter(np.array(image, np.uint8), hue, sat, val)

        box = self._transform_boxes(box, iw, ih, nw, nh, dx, dy, w, h, flip)
        return image_data, box

    @staticmethod
    def _hsv_jitter(image_data, hue, sat, val):
        """Randomly scale the H, S and V channels of a uint8 RGB image.

        Each channel gain is drawn uniformly from [1 - m, 1 + m) for its magnitude
        m. Hue wraps modulo 180 (OpenCV's 8-bit hue range); S and V are clipped to
        0~255. Returns a new uint8 RGB image of the same shape.
        """
        r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
        # Distinct channel names so the magnitude parameters are not shadowed.
        hue_ch, sat_ch, val_ch = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
        dtype = image_data.dtype
        x = np.arange(0, 256, dtype=r.dtype)
        lut_hue = ((x * r[0]) % 180).astype(dtype)
        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
        # cv2.LUT replaces every pixel value v of a channel with lut[v].
        image_data = cv2.merge((cv2.LUT(hue_ch, lut_hue), cv2.LUT(sat_ch, lut_sat), cv2.LUT(val_ch, lut_val)))
        return cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)

    @staticmethod
    def _transform_boxes(box, iw, ih, nw, nh, dx, dy, w, h, flip=False):
        """Map corner boxes from the original (iw, ih) image onto the (w, h) canvas.

        Rows are shuffled in place, scaled by (nw / iw, nh / ih), shifted by
        (dx, dy), optionally mirrored horizontally, and clipped to the canvas.
        Rows whose width or height is not greater than 1 are dropped.
        ``box`` is modified in place; an empty ``box`` is returned unchanged.
        """
        if len(box) == 0:
            return box
        np.random.shuffle(box)
        box[:, [0, 2]] = box[:, [0, 2]] * nw / iw + dx
        box[:, [1, 3]] = box[:, [1, 3]] * nh / ih + dy
        if flip:
            # Mirror x and swap the columns so x1 stays the left edge:
            # x1' = w - x2, x2' = w - x1.
            box[:, [0, 2]] = w - box[:, [2, 0]]
        # Clip the top-left corner to >= 0 and the bottom-right corner to the canvas.
        # A box lying entirely outside ends up with a non-positive width or height
        # and is removed by the filter below.
        box[:, 0:2][box[:, 0:2] < 0] = 0
        box[:, 2][box[:, 2] > w] = w
        box[:, 3][box[:, 3] > h] = h
        box_w = box[:, 2] - box[:, 0]
        box_h = box[:, 3] - box[:, 1]
        return box[np.logical_and(box_w > 1, box_h > 1)]
def yolo_dataset_collate(batch):
    """Collate ``(image, boxes)`` samples; intended as a DataLoader ``collate_fn``.

    Images are stacked into a single float tensor of shape (batch, C, H, W).
    Box arrays can differ in length between images, so they are returned as a
    list with one float tensor per sample.
    """
    samples = list(batch)
    image_list = [img for img, _ in samples]
    box_list = [boxes for _, boxes in samples]
    stacked_images = torch.from_numpy(np.array(image_list)).type(torch.FloatTensor)
    box_tensors = [torch.from_numpy(boxes).type(torch.FloatTensor) for boxes in box_list]
    return stacked_images, box_tensors