Graduation_Project/QN/RecipeRetrieval/dataset/factory.py

38 lines
1.5 KiB
Python

from recipe1m import Recipe1M
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
import transforms
import torchvision.datasets as datasets
class Recipe1MDataModule(pl.LightningDataModule):
    """Lightning data module that serves the torchvision MNIST dataset.

    Despite its name, this class only builds ``datasets.MNIST`` splits; it
    never touches a Recipe1M dataset.
    NOTE(review): presumably a placeholder until the real recipe dataset is
    wired in -- confirm with the author.

    Args:
        data_dir: Root directory handed to ``datasets.MNIST`` for download
            and loading.
        batch_size: Batch size used by the train, val and test dataloaders.
        val_size: Number of samples from the MNIST training set held out for
            validation; the remaining samples form the training split.
    """

    def __init__(self, data_dir: str = "./", batch_size: int = 32, val_size: int = 5000):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.val_size = val_size
        # NOTE(review): (0.1307,), (0.3081,) are presumably the MNIST
        # per-channel mean/std -- confirm before reusing for other data.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

    def prepare_data(self):
        """Download the MNIST train and test sets into ``data_dir``."""
        # Instantiated only for the download side effect; the resulting
        # dataset objects are intentionally discarded.
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the dataset splits needed for ``stage``.

        Args:
            stage: ``'fit'`` or ``'validate'`` builds the train/val splits,
                ``'test'`` builds the test set, ``None`` builds everything.

        Raises:
            ValueError: If ``val_size`` is negative or larger than the MNIST
                training set.
        """
        # 'validate' needs the same split as 'fit'; otherwise val_dataloader()
        # would fail on a missing ``mnist_val`` attribute.
        if stage in ("fit", "validate", None):
            mnist_full = datasets.MNIST(self.data_dir, train=True, transform=self.transform)
            total = len(mnist_full)
            if not 0 <= self.val_size <= total:
                raise ValueError(
                    f"val_size must be between 0 and {total}, got {self.val_size}"
                )
            # Derive the train size from the dataset length instead of
            # hard-coding it (55000/5000 for the default val_size).
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [total - self.val_size, self.val_size]
            )

        if stage in ("test", None):
            self.mnist_test = datasets.MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return a ``DataLoader`` over the training split."""
        # NOTE(review): no shuffle=True here, so batches come in a fixed
        # order every epoch -- confirm this is intentional.
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return a ``DataLoader`` over the validation split."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return a ``DataLoader`` over the MNIST test set."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size)