pytorch-templates/optuna-lightning.ipynb

Imports

In [1]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import torchmetrics
from pytorch_lightning import LightningModule, LightningDataModule, Trainer, callbacks
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
from pytorch_lightning.loggers import CSVLogger
import optuna
from optuna.integration import PyTorchLightningPruningCallback
from optuna import visualization
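
For reproducible runs, the random seeds can optionally be fixed before anything else; `seed_everything` is part of PyTorch Lightning, and the seed value below is an arbitrary choice.

In [ ]:
# Optional: seed Python, NumPy and torch (and dataloader workers) for reproducibility
from pytorch_lightning import seed_everything

seed_everything(42, workers=True)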

Lightning classes and functions

In [2]:
# LightningModule: the model together with its training, validation, test and prediction logic
class MNISTModel(LightningModule):
    def __init__(self, in_dims, num_classes, learning_rate=1e-3, hidden_size=64):
        super().__init__()

        # Set our init args as class attributes
        self.loss_fn = F.cross_entropy
        self.metric = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.learning_rate = learning_rate
        self.save_hyperparameters()

        # Define PyTorch model
        channels, width, height = in_dims
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.Hardswish(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.Hardswish(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        logits = self.model(x)
        return torch.softmax(logits, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = self.loss_fn(logits, y)
        preds = torch.argmax(logits, dim=1)
        metric = self.metric(preds, y)

        self.log("train_loss", loss, prog_bar=False, on_step=False, on_epoch=True)
        self.log("train_metric", metric, prog_bar=False, on_step=False, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = self.loss_fn(logits, y)
        preds = torch.argmax(logits, dim=1)
        metric = self.metric(preds, y)

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_metric", metric, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = self.loss_fn(logits, y)
        preds = torch.argmax(logits, dim=1)
        metric = self.metric(preds, y)

        self.log("test_loss", loss, prog_bar=True)
        self.log("test_metric", metric, prog_bar=True)

    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
        x, y = batch
        logits = self.model(x)
        return torch.softmax(logits, dim=1)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer


# LightningDataModule: downloads, splits and serves the FashionMNIST data
class MNISTData(LightningDataModule):
    def __init__(self, data_dir="./", batch_size=64, num_workers=4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.dims = (1, 28, 28)
        self.num_classes = 10
        self.num_workers = num_workers
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def prepare_data(self):
        # download
        datasets.FashionMNIST(self.data_dir, train=True, download=True)
        datasets.FashionMNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = datasets.FashionMNIST(
                self.data_dir, train=True, transform=self.transform
            )
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage == "predict" or stage is None:
            self.mnist_test = datasets.FashionMNIST(
                self.data_dir, train=False, transform=self.transform
            )

    def train_dataloader(self):
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers
        )

    def test_dataloader(self):
        return DataLoader(
            self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers
        )

    def predict_dataloader(self):
        return DataLoader(
            self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers
        )


# Plot loss and metrics after training and validation
def plot_training(trainer: Trainer):
    metrics = pd.read_csv(
        f"{trainer.logger.log_dir}/metrics.csv",
        usecols=["epoch", "val_loss", "val_metric", "train_loss", "train_metric"],
    )
    metrics.set_index("epoch", inplace=True)
    val_metrics = metrics[["val_loss", "val_metric"]].dropna(axis=0)
    train_metrics = metrics[["train_loss", "train_metric"]].dropna(axis=0)
    metrics = train_metrics.join(val_metrics)
    metrics.plot(
        subplots=[("val_loss", "train_loss"), ("val_metric", "train_metric")],
        figsize=(8, 6),
        title="Training results",
        ylabel="Value",
    )
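
Before launching the search, a quick shape check on a dummy batch confirms the module wiring; the batch size of 8 below is an arbitrary example.

In [ ]:
# Sanity check on a hypothetical dummy batch: forward() returns one probability row per sample
_model = MNISTModel(in_dims=(1, 28, 28), num_classes=10)
_dummy = torch.randn(8, 1, 28, 28)
_probs = _model(_dummy)
assert _probs.shape == (8, 10)
assert torch.allclose(_probs.sum(dim=1), torch.ones(8), atol=1e-4)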

Bayesian optimization with Optuna

In [ ]:
# Optuna objective function
def objective(trial: optuna.trial.Trial) -> float:
    # Select hyperparameters
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
    hidden_size = trial.suggest_int("hidden_size", 2, 2048, log=True)
    epochs = 5

    model = MNISTModel(fashion_mnist.dims, fashion_mnist.num_classes, learning_rate, hidden_size)
    monitor = "val_loss"

    # Create trainer
    swa_callback = callbacks.StochasticWeightAveraging(1e-2)
    optuna_callback = PyTorchLightningPruningCallback(trial, monitor=monitor)
    trainer = Trainer(
        accelerator="auto",
        devices="auto",
        max_epochs=epochs,
        callbacks=[optuna_callback, swa_callback],
        precision="16-mixed",
        log_every_n_steps=20,
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
        enable_model_summary=False,
    )

    # Train model
    trainer.fit(model, fashion_mnist)
    optuna_callback.check_pruned()
    result = trainer.callback_metrics[monitor].item()

    del model
    torch.cuda.empty_cache()

    return result


# Program constants
PATH_DATASETS = "data"
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
MAX_TIME = None
N_TRIALS = 10
N_STARTUP_TRIALS = N_TRIALS // 5

# Instantiate the Lightning data module
fashion_mnist = MNISTData(PATH_DATASETS, BATCH_SIZE)
fashion_mnist.prepare_data()
fashion_mnist.setup(stage="fit")

# Run optimization
sampler = optuna.samplers.TPESampler(n_startup_trials=N_STARTUP_TRIALS)
pruner = optuna.pruners.HyperbandPruner()
study = optuna.create_study(
    study_name="bayesian_optimization",
    sampler=sampler,
    pruner=pruner,
    direction="minimize",
)
study.optimize(
    objective,
    n_trials=N_TRIALS,
    timeout=MAX_TIME,
    n_jobs=1,
    catch=(RuntimeError, AssertionError),
    gc_after_trial=True,
    show_progress_bar=True,
)

# Save optimization results
study_df = study.trials_dataframe()
study_df.to_csv("study_results.csv", index=False)

# Print optimization results
best_params = study.best_params
print("Best params:")
print(best_params)
print("Best loss:")
print(study.best_value)
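
If the search needs to survive interruptions or be shared between processes, the same study can be backed by an RDB storage instead of the default in-memory one; the SQLite file name below is only an example.

In [ ]:
# Optional: persistent, resumable study (assumed SQLite file name)
resumable_study = optuna.create_study(
    study_name="bayesian_optimization",
    storage="sqlite:///optuna_study.db",
    load_if_exists=True,
    sampler=sampler,
    pruner=pruner,
    direction="minimize",
)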

Optimization visualization

In [4]:
visualization.plot_optimization_history(study)
In [5]:
visualization.plot_intermediate_values(study)
In [6]:
visualization.plot_parallel_coordinate(study)
In [7]:
visualization.plot_contour(study)
In [8]:
visualization.plot_param_importances(study)
In [9]:
visualization.plot_edf(study)
In [10]:
visualization.plot_rank(study)
/tmp/ipykernel_3673/1843785596.py:1: ExperimentalWarning:

plot_rank is experimental (supported from v3.2.0). The interface can change in the future.

In [11]:
visualization.plot_slice(study)
In [12]:
visualization.plot_timeline(study)
/tmp/ipykernel_3673/3931394267.py:1: ExperimentalWarning:

plot_timeline is experimental (supported from v3.2.0). The interface can change in the future.
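
All of these helpers return Plotly figures, so any plot can also be written to a standalone HTML file for sharing; the output file name below is arbitrary.

In [ ]:
fig = visualization.plot_optimization_history(study)
fig.write_html("optimization_history.html")  # example file name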

Final model

In [13]:
DIRECTORY = "./"
EPOCHS = 5
PATIENCE_EPOCHS = 3

model = MNISTModel(fashion_mnist.dims, fashion_mnist.num_classes, **best_params)

progressbar_theme = RichProgressBarTheme(
    description="black",
    progress_bar="bright_blue",
    progress_bar_finished="bright_blue",
    progress_bar_pulse="bright_blue",
    batch_progress="bright_blue",
    time="red",
    processing_speed="red",
    metrics="green",
)
progressbar_callback = callbacks.RichProgressBar(theme=progressbar_theme)
model_summary_callback = callbacks.RichModelSummary(max_depth=2)
checkpoint_callback = callbacks.ModelCheckpoint(
    filename="best_weights", save_top_k=1, monitor="val_loss", mode="min"
)
early_stopping_callback = callbacks.EarlyStopping(
    monitor="val_loss", mode="min", patience=PATIENCE_EPOCHS
)
swa_callback = callbacks.StochasticWeightAveraging(1e-2)

# Create trainer
trainer = Trainer(
    default_root_dir=DIRECTORY,
    accelerator="auto",
    devices="auto",
    max_epochs=EPOCHS,
    callbacks=[
        progressbar_callback,
        model_summary_callback,
        checkpoint_callback,
        early_stopping_callback,
        swa_callback,
    ],
    logger=CSVLogger(save_dir=DIRECTORY),
    precision="16-mixed",
)

# Train model
trainer.fit(model, fashion_mnist)
plot_training(trainer)
Using 16bit Automatic Mixed Precision (AMP)
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: ./lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
┏━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃   ┃ Name    ┃ Type               ┃ Params ┃
┡━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0 │ metric  │ MulticlassAccuracy │      0 │
│ 1 │ model   │ Sequential         │  5.0 M │
│ 2 │ model.0 │ Flatten            │      0 │
│ 3 │ model.1 │ Linear             │  1.5 M │
│ 4 │ model.2 │ Hardswish          │      0 │
│ 5 │ model.3 │ Dropout            │      0 │
│ 6 │ model.4 │ Linear             │  3.5 M │
│ 7 │ model.5 │ Hardswish          │      0 │
│ 8 │ model.6 │ Dropout            │      0 │
│ 9 │ model.7 │ Linear             │ 18.7 K │
└───┴─────────┴────────────────────┴────────┘
Trainable params: 5.0 M                                                                                            
Non-trainable params: 0                                                                                            
Total params: 5.0 M                                                                                                
Total estimated model params size (MB): 19                                                                         
`Trainer.fit` stopped: `max_epochs=5` reached.


[Figure: "Training results" plot of train and validation loss and metric per epoch]
In [14]:
# Test final model
checkpoint = checkpoint_callback.best_model_path
model = MNISTModel.load_from_checkpoint(checkpoint)
fashion_mnist.setup(stage="test")
test_results = trainer.test(model, fashion_mnist)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_loss             0.3548879027366638     │
│        test_metric            0.8727999925613403     │
└───────────────────────────┴───────────────────────────┘
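
Finally, a sketch of batch inference with the same trainer: `predict_step` returns class probabilities, so predicted labels follow from an argmax over the concatenated outputs.

In [ ]:
# Inference over the test split via predict_dataloader (sketch)
fashion_mnist.setup(stage="predict")
predictions = trainer.predict(model, fashion_mnist)  # list of per-batch probability tensors
predicted_labels = torch.cat(predictions).argmax(dim=1)
print(predicted_labels[:10])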