32 lines
1.1 KiB
Python
32 lines
1.1 KiB
Python
|
|
|
|
import torch
|
|
from torchvision import datasets, transforms
|
|
|
|
def get_dataset(dir, name):
    """Build the training and evaluation splits of a named image dataset.

    Args:
        dir: Root directory handed to the torchvision dataset constructors.
            The training split is constructed with ``download=True``; the
            evaluation split is not.
        name: Dataset key. Supported values are ``'mnist'``, ``'btmd'`` and
            ``'cifar'`` (CIFAR-10).

    Returns:
        A ``(train_dataset, eval_dataset)`` tuple.

    Raises:
        ValueError: If ``name`` is not one of the supported keys. (Previously
            this surfaced as an ``UnboundLocalError`` at the ``return``.)
    """
    if name == 'mnist':
        train_dataset = datasets.MNIST(dir, train=True, download=True,
                                       transform=transforms.ToTensor())
        eval_dataset = datasets.MNIST(dir, train=False,
                                      transform=transforms.ToTensor())
    elif name == 'btmd':
        # NOTE(review): torchvision does not ship a ``BTMD`` dataset class;
        # this presumably relies on a project-specific or patched
        # ``datasets`` module. Confirm, otherwise this line raises
        # AttributeError.
        train_dataset = datasets.BTMD(dir, train=True, download=True,
                                      transform=transforms.ToTensor())
        eval_dataset = datasets.BTMD(dir, train=False,
                                     transform=transforms.ToTensor())
    elif name == 'cifar':
        # Per-channel (mean, std) normalization over three image channels,
        # shared by the train and test pipelines so the two cannot drift.
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010))
        # Training-time augmentation: random 32x32 crop from the image padded
        # by 4 pixels, plus a random horizontal flip.
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        # Evaluation uses no augmentation, only tensor conversion + normalize.
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = datasets.CIFAR10(dir, train=True, download=True,
                                         transform=transform_train)
        eval_dataset = datasets.CIFAR10(dir, train=False,
                                        transform=transform_test)
    else:
        raise ValueError(
            f"Unknown dataset name {name!r}; "
            "expected one of 'mnist', 'btmd', 'cifar'."
        )
    return train_dataset, eval_dataset