Imports
In [1]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import torchmetrics

from pytorch_lightning import LightningModule, LightningDataModule, Trainer, callbacks
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
from pytorch_lightning.loggers import CSVLogger

import optuna
from optuna.integration import PyTorchLightningPruningCallback
from optuna import visualization
Lightning classes and functions
In [2]:
# LightningModule: the model together with its training, validation, and test logic
class MNISTModel(LightningModule):
    def __init__(self, in_dims, num_classes, learning_rate=1e-3, hidden_size=64):
        super().__init__()
        # Set our init args as class attributes
        self.loss_fn = F.cross_entropy
        self.metric = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.learning_rate = learning_rate
        self.save_hyperparameters()

        # Define PyTorch model
        channels, width, height = in_dims
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.Hardswish(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.Hardswish(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        logits = self.model(x)
        return torch.softmax(logits, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = self.loss_fn(logits, y)
        preds = torch.argmax(logits, dim=1)
        metric = self.metric(preds, y)
        self.log("train_loss", loss, prog_bar=False, on_step=False, on_epoch=True)
        self.log("train_metric", metric, prog_bar=False, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = self.loss_fn(logits, y)
        preds = torch.argmax(logits, dim=1)
        metric = self.metric(preds, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_metric", metric, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = self.loss_fn(logits, y)
        preds = torch.argmax(logits, dim=1)
        metric = self.metric(preds, y)
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_metric", metric, prog_bar=True)

    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
        x, y = batch
        logits = self.model(x)
        return torch.softmax(logits, dim=1)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer


# LightningDataModule: data download, splitting, and dataloaders
class MNISTData(LightningDataModule):
    def __init__(self, data_dir="./", batch_size=64, num_workers=4):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.dims = (1, 28, 28)
        self.num_classes = 10
        self.num_workers = num_workers
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def prepare_data(self):
        # download
        datasets.FashionMNIST(self.data_dir, train=True, download=True)
        datasets.FashionMNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = datasets.FashionMNIST(
                self.data_dir, train=True, transform=self.transform
            )
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage == "predict" or stage is None:
            self.mnist_test = datasets.FashionMNIST(
                self.data_dir, train=False, transform=self.transform
            )

    def train_dataloader(self):
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        return DataLoader(
            self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers
        )

    def test_dataloader(self):
        return DataLoader(
            self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers
        )

    def predict_dataloader(self):
        return DataLoader(
            self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers
        )


# Plot loss and metrics after training and validation
def plot_training(trainer: Trainer):
    metrics = pd.read_csv(
        f"{trainer.logger.log_dir}/metrics.csv",
        usecols=["epoch", "val_loss", "val_metric", "train_loss", "train_metric"],
    )
    metrics.set_index("epoch", inplace=True)
    val_metrics = metrics[["val_loss", "val_metric"]].dropna(axis=0)
    train_metrics = metrics[["train_loss", "train_metric"]].dropna(axis=0)
    metrics = train_metrics.join(val_metrics)
    metrics.plot(
        subplots=[("val_loss", "train_loss"), ("val_metric", "train_metric")],
        figsize=(8, 6),
        title="Training results",
        ylabel="Value",
    )
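A quick shape check (not part of the original run, purely illustrative): the module can be instantiated standalone and fed a random batch to confirm that `forward` returns one softmax row per sample.

In [ ]:
# Hypothetical sanity check: a random batch of four 1x28x28 "images"
# should yield a (4, 10) tensor of class probabilities, each row summing to ~1.
dummy_model = MNISTModel(in_dims=(1, 28, 28), num_classes=10)
dummy_batch = torch.randn(4, 1, 28, 28)
probs = dummy_model(dummy_batch)
print(probs.shape)       # torch.Size([4, 10])
print(probs.sum(dim=1))  # each entry close to 1.0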
Bayesian optimization with Optuna
In [ ]:
# Optuna objective function
def objective(trial: optuna.trial.Trial) -> float:
    # Select hyperparameters
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
    hidden_size = trial.suggest_int("hidden_size", 2, 2048, log=True)
    epochs = 5
    model = MNISTModel(
        fashion_mnist.dims, fashion_mnist.num_classes, learning_rate, hidden_size
    )
    monitor = "val_loss"

    # Create trainer
    swa_callback = callbacks.StochasticWeightAveraging(1e-2)
    optuna_callback = PyTorchLightningPruningCallback(trial, monitor=monitor)
    trainer = Trainer(
        accelerator="auto",
        devices="auto",
        max_epochs=epochs,
        callbacks=[optuna_callback, swa_callback],
        precision="16-mixed",
        log_every_n_steps=20,
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
        enable_model_summary=False,
    )

    # Train model
    trainer.fit(model, fashion_mnist)
    optuna_callback.check_pruned()
    result = trainer.callback_metrics[monitor].item()
    del model
    torch.cuda.empty_cache()
    return result


# Program constants
PATH_DATASETS = "data"
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
MAX_TIME = None
N_TRIALS = 10
N_STARTUP_TRIALS = N_TRIALS // 5

# Instantiate the Lightning data module
fashion_mnist = MNISTData(PATH_DATASETS, BATCH_SIZE)
fashion_mnist.prepare_data()
fashion_mnist.setup(stage="fit")

# Run optimization
sampler = optuna.samplers.TPESampler(n_startup_trials=N_STARTUP_TRIALS)
pruner = optuna.pruners.HyperbandPruner()
study = optuna.create_study(
    study_name="bayesian_optimization",
    sampler=sampler,
    pruner=pruner,
    direction="minimize",
)
study.optimize(
    objective,
    n_trials=N_TRIALS,
    timeout=MAX_TIME,
    n_jobs=1,
    catch=(RuntimeError, AssertionError),
    gc_after_trial=True,
    show_progress_bar=True,
)

# Save optimization results
study_df = study.trials_dataframe()
study_df.to_csv("study_results.csv", index=False)

# Print optimization results
best_params = study.best_params
print("Best params:")
print(best_params)
print("Best loss:")
print(study.best_value)
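Since the trial history is written to `study_results.csv`, it can be inspected later without rerunning the search. The cell below is an illustrative sketch that assumes the column naming produced by `study.trials_dataframe()` (`value`, `state`, and `params_*`).

In [ ]:
# Illustrative only: reload the saved trial history and list the
# five best completed trials by validation loss.
results = pd.read_csv("study_results.csv")
completed = results[results["state"] == "COMPLETE"]
top5 = completed.sort_values("value").head(5)
print(top5[["number", "value", "params_learning_rate", "params_hidden_size"]])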
Optimization visualization
In [4]:
visualization.plot_optimization_history(study)
In [5]:
visualization.plot_intermediate_values(study)
In [6]:
visualization.plot_parallel_coordinate(study)
In [7]:
visualization.plot_contour(study)
In [8]:
visualization.plot_param_importances(study)
In [9]:
visualization.plot_edf(study)
In [10]:
visualization.plot_rank(study)
/tmp/ipykernel_3673/1843785596.py:1: ExperimentalWarning: plot_rank is experimental (supported from v3.2.0). The interface can change in the future.
In [11]:
visualization.plot_slice(study)
In [12]:
visualization.plot_timeline(study)
/tmp/ipykernel_3673/3931394267.py:1: ExperimentalWarning: plot_timeline is experimental (supported from v3.2.0). The interface can change in the future.
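These visualization helpers return Plotly `Figure` objects, so any of the plots above can also be exported, for instance to a standalone HTML file. The snippet below is a sketch only; the output path is arbitrary.

In [ ]:
# Optional: save one of the Optuna plots as a self-contained HTML file.
fig = visualization.plot_optimization_history(study)
fig.write_html("optimization_history.html")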
Final model
In [13]:
DIRECTORY = "./" EPOCHS = 5 PATIENCE_EPOCHS = 3 model = MNISTModel(fashion_mnist.dims, fashion_mnist.num_classes, **best_params) progressbar_theme = RichProgressBarTheme( description="black", progress_bar="bright_blue", progress_bar_finished="bright_blue", progress_bar_pulse="bright_blue", batch_progress="bright_blue", time="red", processing_speed="red", metrics="green", ) progressbar_callback = callbacks.RichProgressBar(theme=progressbar_theme) model_summary_callback = callbacks.RichModelSummary(max_depth=2) checkpoint_callback = callbacks.ModelCheckpoint( filename="best_weights", save_top_k=1, monitor="val_loss", mode="min" ) early_stopping_callback = callbacks.EarlyStopping( monitor="val_loss", mode="min", patience=PATIENCE_EPOCHS ) swa_callback = callbacks.StochasticWeightAveraging(1e-2) # Create trainer trainer = Trainer( default_root_dir=DIRECTORY, accelerator="auto", devices="auto", max_epochs=EPOCHS, callbacks=[ progressbar_callback, model_summary_callback, checkpoint_callback, early_stopping_callback, swa_callback, ], logger=CSVLogger(save_dir=DIRECTORY), precision="16-mixed", ) # Train model trainer.fit(model, fashion_mnist) plot_training(trainer)
Using 16bit Automatic Mixed Precision (AMP)
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: ./lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
┏━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃   ┃ Name    ┃ Type               ┃ Params ┃
┡━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0 │ metric  │ MulticlassAccuracy │      0 │
│ 1 │ model   │ Sequential         │  5.0 M │
│ 2 │ model.0 │ Flatten            │      0 │
│ 3 │ model.1 │ Linear             │  1.5 M │
│ 4 │ model.2 │ Hardswish          │      0 │
│ 5 │ model.3 │ Dropout            │      0 │
│ 6 │ model.4 │ Linear             │  3.5 M │
│ 7 │ model.5 │ Hardswish          │      0 │
│ 8 │ model.6 │ Dropout            │      0 │
│ 9 │ model.7 │ Linear             │ 18.7 K │
└───┴─────────┴────────────────────┴────────┘
Trainable params: 5.0 M
Non-trainable params: 0
Total params: 5.0 M
Total estimated model params size (MB): 19
`Trainer.fit` stopped: `max_epochs=5` reached.
In [14]:
# Test final model
checkpoint = checkpoint_callback.best_model_path
model = MNISTModel.load_from_checkpoint(checkpoint)
fashion_mnist.setup(stage="test")
test_results = trainer.test(model, fashion_mnist)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric        ┃       DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│         test_loss         │    0.3548879027366638     │
│        test_metric        │    0.8727999925613403     │
└───────────────────────────┴───────────────────────────┘
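The `predict_step` and `predict_dataloader` hooks defined earlier are not exercised above. Below is a minimal sketch of how they could be used with the same trainer and data module, assuming the objects from the previous cells are still in memory.

In [ ]:
# Sketch: run inference through the predict_* hooks defined earlier.
# trainer.predict returns a list of per-batch softmax tensors.
fashion_mnist.setup(stage="predict")
batched_probs = trainer.predict(model, fashion_mnist)
probs = torch.cat(batched_probs)        # (num_test_samples, 10)
predicted_classes = probs.argmax(dim=1)
print(predicted_classes[:10])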