"""Train and evaluate a fully connected MNIST classifier with PyTorch Lightning, configured via Hydra."""
import os
|
|
import torch
|
|
from torch import nn
|
|
from torch.utils.data import DataLoader, random_split
|
|
from torchvision import datasets, transforms
|
|
import pytorch_lightning as pl
|
|
from pytorch_lightning.callbacks import ModelCheckpoint
|
|
|
|
from hydra import config_class
|
|
import hydra
|
|
from omegaconf import OmegaConf
|
|
|
|
|
|
class MNISTDataModule(pl.LightningDataModule):
    """Download MNIST and serve normalized train/validation/test DataLoaders.

    The MNIST training split is divided into a training subset and a held-out
    validation subset; the MNIST test split is used for the test stage.

    Args:
        data_dir: Directory where torchvision downloads / reads the MNIST files.
        batch_size: Batch size used by every DataLoader.
        val_size: Number of training images held out for validation
            (the default 5000 reproduces the original 55000/5000 split).
        num_workers: DataLoader worker processes (0 = load in the main process).
        seed: Seed for the train/validation split. Lightning calls ``setup`` on
            every process in distributed training, so an unseeded split could
            give each rank a different, overlapping train/validation partition.
    """

    def __init__(self, data_dir: str = "./", batch_size: int = 32,
                 val_size: int = 5000, num_workers: int = 0, seed: int = 42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.val_size = val_size
        self.num_workers = num_workers
        self.seed = seed
        # ToTensor scales pixels to [0, 1]; Normalize then applies the
        # mean/std values commonly used for MNIST.
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def prepare_data(self):
        """Download both MNIST splits (intended to run once, not per process)."""
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets needed for ``stage``.

        Args:
            stage: "fit", "validate", "test", or None to build everything.
        """
        # "validate" is included so that trainer.validate() also finds mnist_val.
        if stage in ("fit", "validate", None):
            mnist_full = datasets.MNIST(self.data_dir, train=True, transform=self.transform)
            train_size = len(mnist_full) - self.val_size
            self.mnist_train, self.mnist_val = random_split(
                mnist_full,
                [train_size, self.val_size],
                # Fixed generator -> identical split on every run and every rank.
                generator=torch.Generator().manual_seed(self.seed),
            )
        if stage in ("test", None):
            self.mnist_test = datasets.MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the training loader, reshuffled every epoch."""
        # Without shuffle=True every epoch would visit the samples in the same order.
        return DataLoader(self.mnist_train, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size,
                          num_workers=self.num_workers)


class LitModel(pl.LightningModule):
    """Fully connected classifier for 28x28 single-channel images such as MNIST.

    Architecture: flatten -> Linear(784, hidden_1) -> ReLU ->
    Linear(hidden_1, hidden_2) -> ReLU -> Linear(hidden_2, num_classes).
    The defaults reproduce the original 128/256/10 network trained with
    Adam at lr=1e-3.

    Args:
        hidden_1: Width of the first hidden layer.
        hidden_2: Width of the second hidden layer.
        num_classes: Number of output logits.
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, hidden_1: int = 128, hidden_2: int = 256,
                 num_classes: int = 10, lr: float = 1e-3):
        super().__init__()
        # Store the constructor arguments in self.hparams and in checkpoints so
        # LitModel.load_from_checkpoint() rebuilds the same architecture.
        self.save_hyperparameters()
        self.layer_1 = nn.Linear(28 * 28, hidden_1)
        self.layer_2 = nn.Linear(hidden_1, hidden_2)
        self.layer_3 = nn.Linear(hidden_2, num_classes)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes).

        ``x`` is flattened after the batch dimension, so (B, 1, 28, 28) image
        batches and already-flattened (B, 784) inputs are both accepted.
        """
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.layer_1(x))
        x = torch.relu(self.layer_2(x))
        return self.layer_3(x)

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

    def _shared_step(self, batch, stage):
        """Compute and log cross-entropy loss and accuracy for one (x, y) batch.

        Metrics are logged as ``{stage}_loss`` and ``{stage}_acc``; the loss
        tensor is returned.
        """
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        # sync_dist averages epoch-level val/test metrics across devices in
        # multi-GPU runs (ModelCheckpoint monitors val_loss); it is skipped for
        # the per-step training loss to avoid a collective on every step.
        sync = stage != "train"
        self.log(f"{stage}_loss", loss, sync_dist=sync)
        self.log(f"{stage}_acc", acc, sync_dist=sync)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for Lightning to backpropagate."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy."""
        self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Log test loss/accuracy."""
        self._shared_step(batch, "test")


@hydra.main(config_path="config", config_name="config", version_base='1.1')
def main(cfg: config_class.Config):
    """Train the MNIST classifier, keep the best checkpoint by ``val_loss``, then test it.

    Args:
        cfg: Configuration composed by Hydra from the ``config`` directory
            (``config_name="config"``) plus any command-line overrides.
            NOTE(review): Hydra passes an omegaconf ``DictConfig`` here; confirm
            that ``hydra.config_class`` actually resolves, since hydra-core does
            not appear to ship such a module.
    """
    cfg_dict = OmegaConf.to_container(cfg, resolve=True)  # Convert to a native Python dict
    # Debug output: first configured hidden-layer width (not yet wired into LitModel).
    print(cfg_dict['model']['model']['hidden_layers'][0])

    # version_base='1.1' makes Hydra chdir into a fresh per-run output directory,
    # so data_dir="./" would re-download MNIST on every run. Keep the dataset in
    # the directory the script was launched from instead.
    data_module = MNISTDataModule(data_dir=hydra.utils.get_original_cwd())
    model = LitModel()

    trainer = pl.Trainer(
        max_epochs=5,
        # "auto" uses every visible GPU (as accelerator='gpu', devices=-1 did)
        # but falls back to the CPU instead of raising when no GPU is available.
        accelerator='auto',
        devices='auto',
        # Relative dirpath: checkpoints are written inside the Hydra run directory.
        callbacks=[ModelCheckpoint(dirpath='./checkpoints/', save_top_k=1, monitor='val_loss')],
    )

    trainer.fit(model, data_module)

    # Evaluate the best checkpoint explicitly instead of relying on the
    # implicit default used when ckpt_path is omitted.
    trainer.test(datamodule=data_module, ckpt_path='best')


if __name__ == '__main__':
    # Hydra's decorator builds `cfg` from the config files and CLI overrides,
    # so no argument is passed here.
    main()


# # Inference
# model_path = "path/to/your_model.ckpt"  # Update this path to your actual model checkpoint path
# trained_model = LitModel.load_from_checkpoint(checkpoint_path=model_path)
# trained_model.eval()
# trained_model.freeze()  # Optional: disables gradients for all parameters before inference or export

# from PIL import Image

# # Load an image
# image_path = "path/to/your_image.png"
# image = Image.open(image_path).convert("L")  # Convert to grayscale
# # MNIST digits are white on a black background; invert dark-on-light images
# # first (e.g. with PIL.ImageOps.invert(image)), or predictions will be poor.

# # Transform the image to tensor
# transform = transforms.Compose([
#     transforms.Resize((28, 28)),  # Resize to the same size as MNIST images
#     transforms.ToTensor(),
#     transforms.Normalize((0.1307,), (0.3081,))  # Same normalization as during training
# ])

# image = transform(image).unsqueeze(0)  # Add batch dimension

# # Make a prediction
# with torch.no_grad():  # Disable gradient computation for inference
#     prediction = trained_model(image)

# predicted_label = prediction.argmax(dim=1)
# print(f"Predicted Label: {predicted_label.item()}")