lili_code/datasets/base_dataset.py

87 lines
2.7 KiB
Python

import os.path as osp
import os
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset
import torchvision
import utils.transforms as tf
from .registry import DATASETS
@DATASETS.register_module
class BaseDataset(Dataset):
def __init__(self, img_path, data_list, list_path='list', cfg=None):
self.cfg = cfg
self.img_path = img_path
self.list_path = osp.join(img_path, list_path)
self.data_list = data_list
self.is_training = ('train' in data_list)
self.img_name_list = []
self.full_img_path_list = []
self.label_list = []
self.exist_list = []
self.transform = self.transform_train() if self.is_training else self.transform_val()
self.init()
def transform_train(self):
raise NotImplementedError()
def transform_val(self):
val_transform = torchvision.transforms.Compose([
tf.SampleResize((self.cfg.img_width, self.cfg.img_height)),
tf.GroupNormalize(mean=(self.cfg.img_norm['mean'], (0, )), std=(
self.cfg.img_norm['std'], (1, ))),
])
return val_transform
def view(self, img, coords, file_path=None):
for coord in coords:
for x, y in coord:
if x <= 0 or y <= 0:
continue
x, y = int(x), int(y)
cv2.circle(img, (x, y), 4, (255, 0, 0), 2)
if file_path is not None:
if not os.path.exists(osp.dirname(file_path)):
os.makedirs(osp.dirname(file_path))
cv2.imwrite(file_path, img)
def init(self):
raise NotImplementedError()
def __len__(self):
return len(self.full_img_path_list)
def __getitem__(self, idx):
img = cv2.imread(self.full_img_path_list[idx]).astype(np.float32)
img = img[self.cfg.cut_height:, :, :]
if self.is_training:
label = cv2.imread(self.label_list[idx], cv2.IMREAD_UNCHANGED)
if len(label.shape) > 2:
label = label[:, :, 0]
label = label.squeeze()
label = label[self.cfg.cut_height:, :]
exist = self.exist_list[idx]
if self.transform:
img, label = self.transform((img, label))
label = torch.from_numpy(label).contiguous().long()
else:
img, = self.transform((img,))
img = torch.from_numpy(img).permute(2, 0, 1).contiguous().float()
meta = {'full_img_path': self.full_img_path_list[idx],
'img_name': self.img_name_list[idx]}
data = {'img': img, 'meta': meta}
if self.is_training:
data.update({'label': label, 'exist': exist})
return data