42 lines
1.8 KiB
Python
42 lines
1.8 KiB
Python
import torch
|
||
import torch.nn.functional as F
|
||
import numpy as np
|
||
import albumentations as A
|
||
|
||
# boxes rows are laid out as (cls, x, y, w, h)
def horizontal_flip(images, boxes):
    """Mirror images left-to-right and mirror the box x-centres to match.

    Args:
        images: NumPy array that is flipped along its last axis
            (assumed to be the width axis -- confirm against the caller).
        boxes: 2-D NumPy array with one row per box, laid out as
            ``(cls, x, y, w, h)``. Column 1 is replaced by ``1 - x``, so
            x is expected to be normalized to ``[0, 1]``.

    Returns:
        ``(flipped_images, flipped_boxes)``. Both are new arrays; neither
        input is modified.
    """
    # np.flip returns a negative-stride view of the input. Materialize it so
    # the result does not alias `images` and can be passed to consumers that
    # reject negative strides (e.g. torch.from_numpy).
    flipped_images = np.ascontiguousarray(np.flip(images, axis=-1))

    # Work on a copy: writing into `boxes` directly would silently alter the
    # caller's label array (e.g. labels cached across epochs).
    flipped_boxes = np.array(boxes, copy=True)
    flipped_boxes[:, 1] = 1 - flipped_boxes[:, 1]

    return flipped_images, flipped_boxes
|
||
|
||
# image: np.uint8 array; boxes: NumPy array with rows (cls, x, y, w, h)
def augment(image, boxes):
    """Run the albumentations pipeline on one image and its bounding boxes.

    Args:
        image: Image array whose first two axes are height and width.
        boxes: 2-D NumPy array with one row per box, laid out as
            ``(cls, x, y, w, h)`` with a centre-based, normalized box.
            x and w are treated as fractions of the image width, y and h
            as fractions of the image height (for square images this is
            the same as scaling everything by the height).

    Returns:
        ``(image, boxes)`` in the same layout as the inputs. If the
        pipeline leaves no boxes, the original ``image`` and ``boxes``
        objects are returned unchanged.
    """
    img_h, img_w = image.shape[:2]
    to_pixels = np.array([img_w, img_h, img_w, img_h], dtype=np.float64)

    labels = boxes[:, 0].tolist()
    # Multiplication allocates a new array, so the caller's `boxes` is not
    # modified. Coordinates are now in pixels of the source image.
    pixel_boxes = boxes[:, 1:] * to_pixels

    # Centre/size -> corner form. Clip BOTH corners to the image and derive
    # the size from the clipped corners, so a box that pokes outside the image
    # is trimmed rather than shifted, and the far edge is valid as well.
    half_size = pixel_boxes[:, 2:] / 2
    limits = to_pixels[:2]
    top_left = np.clip(pixel_boxes[:, :2] - half_size, 0, limits)
    bottom_right = np.clip(pixel_boxes[:, :2] + half_size, 0, limits)
    # COCO layout expected by the pipeline: [x_min, y_min, width, height]
    coco_boxes = np.concatenate(
        (top_left, bottom_right - top_left), axis=1
    ).tolist()

    # Configure the augmentation methods here.
    aug = A.Compose([
        A.HorizontalFlip(p=0.5),
        # Optional transforms, disabled by default:
        # A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=10, val_shift_limit=10, p=0.5),
        # A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=5, border_mode=0, p=0.5)
    ], bbox_params={'format': 'coco', 'label_fields': ['category_id']})

    augmented = aug(image=image, bboxes=coco_boxes, category_id=labels)

    # If augmentation made every box disappear, fall back to the original
    # sample. len() works whether the library returns a list or an array.
    if len(augmented['bboxes']) == 0:
        return image, boxes

    out_image = augmented['image']
    out_h, out_w = out_image.shape[:2]

    # [x_min, y_min, w, h] -> [x_centre, y_centre, w, h], then re-normalize
    # by the size of the augmented image.
    out_boxes = np.array(augmented['bboxes'], dtype=np.float64)
    out_boxes[:, :2] = out_boxes[:, :2] + out_boxes[:, 2:] / 2
    out_boxes = out_boxes / np.array([out_w, out_h, out_w, out_h], dtype=np.float64)

    out_labels = np.array(augmented['category_id'])[:, None]
    return out_image, np.concatenate((out_labels, out_boxes), axis=1)