87 lines
2.7 KiB
Python
87 lines
2.7 KiB
Python
import os.path as osp
|
|
import os
|
|
import numpy as np
|
|
import cv2
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
import torchvision
|
|
import utils.transforms as tf
|
|
from .registry import DATASETS
|
|
|
|
|
|
@DATASETS.register_module
class BaseDataset(Dataset):
    """Base class for image datasets with optional per-pixel label maps.

    Subclasses must implement :meth:`init` (expected to fill
    ``img_name_list``, ``full_img_path_list`` and, for training,
    ``label_list`` and ``exist_list``) and :meth:`transform_train`.

    Args:
        img_path: Root directory of the dataset.
        data_list: Name of the split/list to load. The dataset is put in
            training mode when this contains the substring ``'train'``.
        list_path: Sub-path under ``img_path`` that holds the list files.
        cfg: Configuration object. This class reads ``cut_height``,
            ``img_width``, ``img_height`` and ``img_norm`` from it.
    """

    def __init__(self, img_path, data_list, list_path='list', cfg=None):
        self.cfg = cfg
        self.img_path = img_path
        self.list_path = osp.join(img_path, list_path)
        self.data_list = data_list
        # Substring test: any split name containing 'train' enables training.
        self.is_training = ('train' in data_list)

        # Parallel per-sample lists; left empty here for init() to fill.
        self.img_name_list = []
        self.full_img_path_list = []
        self.label_list = []
        self.exist_list = []

        self.transform = self.transform_train() if self.is_training else self.transform_val()

        self.init()

    @staticmethod
    def _read_image(path, flags=None):
        """Read an image with OpenCV, raising ``OSError`` if it cannot be read.

        ``cv2.imread`` reports failure by returning ``None`` instead of
        raising; without this check the failure only surfaces later as an
        ``AttributeError`` that does not mention the offending path.
        """
        image = cv2.imread(path) if flags is None else cv2.imread(path, flags)
        if image is None:
            raise OSError('Failed to read image: {!r}'.format(path))
        return image

    def transform_train(self):
        """Return the training transform; must be provided by subclasses."""
        raise NotImplementedError()

    def transform_val(self):
        """Return the validation transform: a resize followed by normalization."""
        # NOTE(review): the second entry of mean/std ((0,) and (1,)) is
        # presumably the identity normalization for a label channel --
        # confirm against utils.transforms.GroupNormalize.
        val_transform = torchvision.transforms.Compose([
            tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
            tf.GroupNormalize(
                mean=(self.cfg.img_norm['mean'], (0, )),
                std=(self.cfg.img_norm['std'], (1, )),
            ),
        ])
        return val_transform

    def view(self, img, coords, file_path=None):
        """Draw coordinate points on ``img`` in place and optionally save it.

        Args:
            img: Image array accepted by ``cv2.circle``; modified in place.
            coords: Iterable of point sequences, each yielding ``(x, y)``.
                Points with ``x <= 0`` or ``y <= 0`` are skipped.
            file_path: If not ``None``, the annotated image is written here;
                the parent directory is created when missing.
        """
        for coord in coords:
            for x, y in coord:
                if x <= 0 or y <= 0:
                    continue
                x, y = int(x), int(y)
                cv2.circle(img, (x, y), 4, (255, 0, 0), 2)

        if file_path is not None:
            out_dir = osp.dirname(file_path)
            # A bare file name has an empty dirname; os.makedirs('') would
            # raise, so only create a directory when there is one to create.
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            cv2.imwrite(file_path, img)

    def init(self):
        """Populate the sample lists; must be provided by subclasses."""
        raise NotImplementedError()

    def __len__(self):
        """Return the number of samples (length of ``full_img_path_list``)."""
        return len(self.full_img_path_list)

    def __getitem__(self, idx):
        """Load, crop and transform the sample at ``idx``.

        Returns:
            dict with ``'img'`` (float tensor, axes permuted ``(2, 0, 1)``)
            and ``'meta'`` (``full_img_path`` and ``img_name``). In training
            mode it also carries ``'label'`` (long tensor) and ``'exist'``.

        Raises:
            OSError: If the image or label file cannot be read.
        """
        img = self._read_image(self.full_img_path_list[idx]).astype(np.float32)
        # Drop the top cfg.cut_height rows of the image.
        img = img[self.cfg.cut_height:, :, :]

        if self.is_training:
            label = self._read_image(self.label_list[idx], cv2.IMREAD_UNCHANGED)
            # Multi-channel label files: keep only the first channel.
            if len(label.shape) > 2:
                label = label[:, :, 0]
            label = label.squeeze()
            # Apply the same top crop as the image so they stay aligned.
            label = label[self.cfg.cut_height:, :]
            exist = self.exist_list[idx]
            if self.transform:
                img, label = self.transform((img, label))
            label = torch.from_numpy(label).contiguous().long()
        else:
            img, = self.transform((img,))

        img = torch.from_numpy(img).permute(2, 0, 1).contiguous().float()
        meta = {'full_img_path': self.full_img_path_list[idx],
                'img_name': self.img_name_list[idx]}

        data = {'img': img, 'meta': meta}
        if self.is_training:
            data.update({'label': label, 'exist': exist})
        return data
|