import os

import PIL.Image as Image
from torch.utils.data import Dataset


# Reference only: the original paired (image, mask) listing helper, currently
# disabled. It is kept to show the directory layout the full dataset used.
#
# def make_dataset(root):
#     # root = "./data/train"
#     imgs = []
#     ori_path = os.path.join(root, "Data")
#     ground_path = os.path.join(root, "Ground")
#     names = os.listdir(ori_path)
#     n = len(names)
#     for i in range(n):
#         img = os.path.join(ori_path, names[i])
#         mask = os.path.join(ground_path, names[i])
#         imgs.append((img, mask))
#     return imgs


class LiverDataset(Dataset):
    """Single-image dataset that serves one fixed input image.

    The dataset always points at ``DEFAULT_IMAGE_PATH``; it exposes exactly
    one sample (index ``0``), which is the image loaded in ``'L'`` mode and
    passed through ``transform`` when one is given. Ground-truth mask loading
    is disabled, so no target is returned.

    Args:
        root: Accepted for interface compatibility but currently ignored.
            TODO: derive the image path from ``root`` instead of the
            hard-coded absolute path.
        transform: Optional callable applied to the loaded image.
        target_transform: Optional callable stored on the instance; unused
            while mask loading is disabled.
    """

    # Hard-coded sample location; see the TODO in the class docstring.
    DEFAULT_IMAGE_PATH = "/home/shared/wy/flask_web/Unet_liver_seg-master/data/val/Data/P2_T1_00018.png"

    def __init__(self, root, transform=None, target_transform=None):
        # ``imgs`` stays a plain path string (not a list) so code that reads
        # the attribute directly keeps working.
        self.imgs = self.DEFAULT_IMAGE_PATH
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return the (optionally transformed) image for ``index``.

        Raises:
            IndexError: If ``index`` is outside the single-sample range.
                Raising here is what lets plain ``for x in dataset``
                iteration terminate instead of looping forever.
        """
        size = len(self)
        if not -size <= index < size:
            raise IndexError(
                f"index {index} is out of range for a dataset of size {size}"
            )

        # The context manager closes the file handle promptly; ``convert``
        # returns a new, fully loaded image that outlives the ``with`` block.
        with Image.open(self.imgs) as source:
            img_x = source.convert('L')

        if self.transform is not None:
            img_x = self.transform(img_x)
        # Mask loading / ``target_transform`` is intentionally not applied:
        # only the input image is returned.
        return img_x

    def __len__(self):
        """Return the number of samples, which is always one.

        ``self.imgs`` is a single path string, so ``len(self.imgs)`` would be
        the character count of the path rather than the number of images.
        """
        return 1