# File-viewer header (not part of the program): 38 lines, 1.5 KiB, Python
import pytorch_lightning as pl
import torch
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, random_split

import transforms
from recipe1m import Recipe1M


class Recipe1MDataModule(pl.LightningDataModule):
    """Lightning data module exposing train / validation / test dataloaders.

    NOTE(review): despite the class name, every dataset built in this class
    is ``datasets.MNIST``; ``Recipe1M`` is not referenced here. Confirm
    whether this is a placeholder.

    Args:
        data_dir: Root directory handed to ``datasets.MNIST``.
        batch_size: Batch size used by all three dataloaders.
        val_size: Number of samples of the training set held out for
            validation. The remaining samples form the training subset.
        split_seed: Seed of the generator that drives the train/val split,
            so the split is reproducible and identical in every process
            that calls :meth:`setup`.
    """

    def __init__(
        self,
        data_dir: str = "./",
        batch_size: int = 32,
        val_size: int = 5000,
        split_seed: int = 42,
    ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.val_size = val_size
        self.split_seed = split_seed
        # Convert to tensor, then normalise the single channel.
        # 0.1307 / 0.3081 are presumably the MNIST mean / std -- TODO confirm.
        # NOTE(review): `transforms` is whatever the file-level
        # `import transforms` resolves to; assumed torchvision-compatible.
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def prepare_data(self):
        """Download the MNIST train and test sets into ``data_dir``.

        Only triggers the download; no state is assigned on ``self``.
        """
        for is_train in (True, False):
            datasets.MNIST(self.data_dir, train=is_train, download=True)

    def setup(self, stage=None):
        """Build the datasets needed for ``stage``.

        Args:
            stage: ``'fit'`` or ``'validate'`` builds the train/val split,
                ``'test'`` builds the test set, ``None`` builds everything.
                Any other value builds nothing.

        Raises:
            ValueError: If ``val_size`` is negative or larger than the
                training set.
        """
        # 'validate' needs `mnist_val` too; without it `val_dataloader`
        # would fail with AttributeError when only validation is run.
        if stage in ("fit", "validate", None):
            mnist_full = datasets.MNIST(
                self.data_dir, train=True, transform=self.transform
            )
            total = len(mnist_full)
            if not 0 <= self.val_size <= total:
                raise ValueError(
                    f"val_size must be between 0 and {total}, got {self.val_size}"
                )
            # A dedicated seeded generator keeps the split independent of
            # the global RNG state.
            split_rng = torch.Generator().manual_seed(self.split_seed)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full,
                [total - self.val_size, self.val_size],
                generator=split_rng,
            )

        if stage in ("test", None):
            self.mnist_test = datasets.MNIST(
                self.data_dir, train=False, transform=self.transform
            )

    def train_dataloader(self):
        """Return a loader over the training subset.

        Requires :meth:`setup` to have built ``mnist_train``. No per-epoch
        shuffling is requested.
        """
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return a loader over the validation subset built by :meth:`setup`."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return a loader over the test set built by :meth:`setup`."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size)