# Graduation_Project/QN/RecipeRetrieval/train.py
#
# 133 lines
# 4.5 KiB
# Python

import os
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from hydra import config_class
import hydra
from omegaconf import OmegaConf
class MNISTDataModule(pl.LightningDataModule):
    """LightningDataModule serving MNIST train/val/test dataloaders.

    The 60k-image MNIST training set is split into 55k train / 5k validation
    images; the separate MNIST test set is used for the test stage.

    Args:
        data_dir: Directory where torchvision stores / looks for the MNIST files.
        batch_size: Batch size used by every dataloader.
        num_workers: DataLoader worker processes (0 = load in the main process,
            which matches the previous behaviour).
        seed: Seed for the train/val split, so every run -- and every process
            of a multi-device run -- sees the same partition.
    """

    def __init__(self, data_dir: str = "./", batch_size: int = 32,
                 num_workers: int = 0, seed: int = 42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        # (0.1307,), (0.3081,) are the widely used MNIST pixel mean/std; the
        # same values must be used at inference time.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

    def prepare_data(self):
        """Download the MNIST train and test files into ``data_dir`` (no assignments)."""
        datasets.MNIST(self.data_dir, train=True, download=True)
        datasets.MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the dataset splits needed for ``stage``.

        ``"fit"`` and ``"validate"`` build the train/val split, ``"test"``
        builds the test set, and ``None`` builds everything.
        """
        if stage in ("fit", "validate", None):
            mnist_full = datasets.MNIST(self.data_dir, train=True, transform=self.transform)
            # Lightning runs setup() in every process; an unseeded split would
            # give each rank a different partition, leaking validation images
            # into other ranks' training data. A fixed generator prevents that
            # and makes the split reproducible across runs.
            generator = torch.Generator().manual_seed(self.seed)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=generator)
        if stage in ("test", None):
            self.mnist_test = datasets.MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Return the shuffled training loader (requires ``setup('fit')``)."""
        # Shuffling is essential for SGD-style training; without it every
        # epoch iterates the samples in the same fixed order.
        return DataLoader(self.mnist_train, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        """Return the validation loader, in fixed order."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        """Return the test loader, in fixed order."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size,
                          num_workers=self.num_workers)
class LitModel(pl.LightningModule):
    """Three-layer MLP classifier (defaults sized for 28x28 MNIST, 10 classes).

    Args:
        hidden_1: Width of the first hidden layer.
        hidden_2: Width of the second hidden layer.
        num_classes: Number of output logits.
        lr: Adam learning rate.
        input_dim: Number of features after flattening each sample.
    """

    def __init__(self, hidden_1: int = 128, hidden_2: int = 256,
                 num_classes: int = 10, lr: float = 1e-3,
                 input_dim: int = 28 * 28):
        super().__init__()
        # Record constructor args in checkpoints so load_from_checkpoint()
        # rebuilds the same architecture; checkpoints without hparams fall
        # back to the defaults above.
        self.save_hyperparameters()
        # Attribute names are kept so existing state_dicts still load.
        self.layer_1 = nn.Linear(input_dim, hidden_1)
        self.layer_2 = nn.Linear(hidden_1, hidden_2)
        self.layer_3 = nn.Linear(hidden_2, num_classes)

    def forward(self, x):
        """Return unnormalised class logits of shape ``(batch, num_classes)``.

        ``x`` may have any shape whose first dim is the batch; the remaining
        dims are flattened and must total ``input_dim`` features.
        """
        x = x.flatten(start_dim=1)
        x = torch.relu(self.layer_1(x))
        x = torch.relu(self.layer_2(x))
        return self.layer_3(x)

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the cross-entropy loss for one batch."""
        x, y = batch
        loss = nn.functional.cross_entropy(self(x), y)
        self.log("train_loss", loss)
        return loss

    def _evaluate(self, batch, prefix):
        """Log ``<prefix>_loss`` and ``<prefix>_acc`` for an evaluation batch."""
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        # sync_dist reduces across devices so epoch-level metrics (and the
        # ModelCheckpoint monitor on val_loss) are global in multi-GPU runs.
        self.log(f"{prefix}_loss", loss, sync_dist=True)
        self.log(f"{prefix}_acc", acc, sync_dist=True)

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy."""
        self._evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        """Log test loss/accuracy."""
        self._evaluate(batch, "test")
@hydra.main(config_path="config", config_name="config", version_base='1.1')
def main(cfg: config_class.Config):
    """Train the MNIST MLP, then evaluate its best checkpoint on the test set.

    Hydra loads ``cfg`` from ``config/config.yaml`` and injects it.
    NOTE(review): ``cfg`` is only printed below; model/trainer settings are
    still hard-coded. Also confirm that ``hydra.config_class`` (imported at the
    top of the file) actually exists in the installed Hydra.
    """
    cfg_dict = OmegaConf.to_container(cfg, resolve=True)  # plain Python dict
    print(cfg_dict['model']['model']['hidden_layers'][0])

    # With version_base='1.1' Hydra chdirs into a fresh per-run output
    # directory, so a relative data_dir would re-download MNIST on every run.
    # Anchor the dataset at the directory the script was launched from.
    data_module = MNISTDataModule(data_dir=hydra.utils.get_original_cwd())
    model = LitModel()

    # 'auto' selects the available accelerator (all GPUs when present, as
    # devices=-1 did) instead of raising on machines without a GPU.
    # Checkpoints are written relative to the Hydra run directory.
    trainer = pl.Trainer(
        max_epochs=5,
        accelerator='auto',
        devices='auto',
        callbacks=[ModelCheckpoint(dirpath='./checkpoints/', save_top_k=1, monitor='val_loss')],
    )
    trainer.fit(model, data_module)

    # Evaluate the best (lowest val_loss) checkpoint from the fit above;
    # naming ckpt_path explicitly avoids Lightning's "no model passed" warning.
    trainer.test(datamodule=data_module, ckpt_path='best')
if __name__ == "__main__":
    # The @hydra.main decorator builds and injects `cfg`, so no argument here.
    main()
# # Inference
# model_path = "path/to/your_model.ckpt" # Update this path to your actual model checkpoint path
# trained_model = LitModel.load_from_checkpoint(checkpoint_path=model_path)
# trained_model.eval()
# trained_model.freeze() # Optional in PyTorch Lightning to prepare the model for exporting or making it ready for inference
# from PIL import Image
# # Load an image
# image_path = "path/to/your_image.png"
# image = Image.open(image_path).convert("L") # Convert to grayscale
# # Transform the image to tensor
# transform = transforms.Compose([
# transforms.Resize((28, 28)), # Resize to the same size as MNIST images
# transforms.ToTensor(),
# transforms.Normalize((0.1307,), (0.3081,)) # Same normalization as during training
# ])
# image = transform(image).unsqueeze(0) # Add batch dimension
# # Make a prediction
# with torch.no_grad(): # Disable gradient computation for inference
#     prediction = trained_model(image)
# predicted_label = prediction.argmax(dim=1)
# print(f"Predicted Label: {predicted_label.item()}")