{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training and testing\n",
    "\n",
    "Train a small fully-connected classifier on Fashion-MNIST with PyTorch Lightning, evaluate it on the test split, and export it for inference with Lightning, TorchScript and ONNX."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "from torch.utils.data import DataLoader, random_split\n",
    "from torchvision import datasets, transforms\n",
    "import torchmetrics\n",
    "from lightning.pytorch import LightningModule, LightningDataModule, Trainer, callbacks\n",
    "from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme\n",
    "from lightning.pytorch.loggers import CSVLogger\n",
    "import onnxruntime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lightning module. Model with its usage\n",
    "class MNISTModel(LightningModule):\n",
    "    \"\"\"Fully-connected classifier for single-channel images such as Fashion-MNIST.\n",
    "\n",
    "    Args:\n",
    "        in_dims: ``(channels, width, height)`` of one input image.\n",
    "        num_classes: number of target classes.\n",
    "        learning_rate: AdamW learning rate.\n",
    "        hidden_size: width of the two hidden layers.\n",
    "\n",
    "    ``forward`` and ``predict_step`` return class probabilities (softmax), while the\n",
    "    train/val/test steps work on raw logits because ``F.cross_entropy`` expects logits.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_dims, num_classes, learning_rate=1e-3, hidden_size=64):\n",
    "        super().__init__()\n",
    "\n",
    "        # Set our init args as class attributes\n",
    "        self.loss_fn = F.cross_entropy\n",
    "        self.learning_rate = learning_rate\n",
    "        self.save_hyperparameters()\n",
    "\n",
    "        # One accuracy metric per stage: torchmetrics objects accumulate state, so a\n",
    "        # single shared instance would mix train, validation and test statistics.\n",
    "        self.train_metric = torchmetrics.Accuracy(task=\"multiclass\", num_classes=num_classes)\n",
    "        self.val_metric = torchmetrics.Accuracy(task=\"multiclass\", num_classes=num_classes)\n",
    "        self.test_metric = torchmetrics.Accuracy(task=\"multiclass\", num_classes=num_classes)\n",
    "\n",
    "        # Define PyTorch model\n",
    "        channels, width, height = in_dims\n",
    "        self.model = nn.Sequential(\n",
    "            nn.Flatten(),\n",
    "            nn.Linear(channels * width * height, hidden_size),\n",
    "            nn.Hardswish(),\n",
    "            nn.Dropout(0.1),\n",
    "            nn.Linear(hidden_size, hidden_size),\n",
    "            nn.Hardswish(),\n",
    "            nn.Dropout(0.1),\n",
    "            nn.Linear(hidden_size, num_classes),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return class probabilities (softmax over the class dimension).\"\"\"\n",
    "        logits = self.model(x)\n",
    "        return torch.softmax(logits, dim=1)\n",
    "\n",
    "    def _shared_step(self, batch):\n",
    "        \"\"\"Return ``(loss, predicted labels, targets)`` for one ``(x, y)`` batch.\"\"\"\n",
    "        x, y = batch\n",
    "        logits = self.model(x)\n",
    "        loss = self.loss_fn(logits, y)\n",
    "        preds = torch.argmax(logits, dim=1)\n",
    "        return loss, preds, y\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        loss, preds, y = self._shared_step(batch)\n",
    "        self.train_metric(preds, y)\n",
    "\n",
    "        # Logging the metric object (not a per-batch value) lets Lightning compute\n",
    "        # the exact epoch accuracy and reset the metric state at the end of the epoch.\n",
    "        self.log(\"train_loss\", loss, prog_bar=False, on_step=False, on_epoch=True)\n",
    "        self.log(\n",
    "            \"train_metric\", self.train_metric, prog_bar=False, on_step=False, on_epoch=True\n",
    "        )\n",
    "\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        loss, preds, y = self._shared_step(batch)\n",
    "        self.val_metric(preds, y)\n",
    "\n",
    "        self.log(\"val_loss\", loss, prog_bar=True)\n",
    "        self.log(\"val_metric\", self.val_metric, prog_bar=True)\n",
    "\n",
    "    def test_step(self, batch, batch_idx):\n",
    "        loss, preds, y = self._shared_step(batch)\n",
    "        self.test_metric(preds, y)\n",
    "\n",
    "        self.log(\"test_loss\", loss, prog_bar=True)\n",
    "        self.log(\"test_metric\", self.test_metric, prog_bar=True)\n",
    "\n",
    "    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):\n",
    "        # Targets are ignored; the output is the same as forward(): class probabilities.\n",
    "        x, _ = batch\n",
    "        return self(x)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)\n",
    "        return optimizer\n",
    "\n",
    "\n",
    "# LightningDataModule. Data management\n",
    "class MNISTData(LightningDataModule):\n",
    "    \"\"\"Download, split and serve Fashion-MNIST through DataLoaders.\n",
    "\n",
    "    Args:\n",
    "        data_dir: directory the dataset is downloaded to and read from.\n",
    "        batch_size: batch size of every DataLoader.\n",
    "        num_workers: number of worker processes of every DataLoader.\n",
    "        seed: seed of the train/validation split, so the split is identical on every run.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, data_dir=\"./\", batch_size=64, num_workers=4, seed=42):\n",
    "        super().__init__()\n",
    "        self.data_dir = data_dir\n",
    "        self.batch_size = batch_size\n",
    "        self.dims = (1, 28, 28)\n",
    "        self.num_classes = 10\n",
    "        self.num_workers = num_workers\n",
    "        self.seed = seed\n",
    "        # NOTE(review): (0.1307, 0.3081) are the usual MNIST mean/std; Fashion-MNIST\n",
    "        # statistics differ, but changing them here requires retraining the model.\n",
    "        self.transform = transforms.Compose(\n",
    "            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]\n",
    "        )\n",
    "\n",
    "    def prepare_data(self):\n",
    "        # download\n",
    "        datasets.FashionMNIST(self.data_dir, train=True, download=True)\n",
    "        datasets.FashionMNIST(self.data_dir, train=False, download=True)\n",
    "\n",
    "    def setup(self, stage=None):\n",
    "        # Assign train/val datasets; trainer.validate() runs with stage=\"validate\"\n",
    "        # and needs the validation split as well.\n",
    "        if stage in (\"fit\", \"validate\") or stage is None:\n",
    "            mnist_full = datasets.FashionMNIST(\n",
    "                self.data_dir, train=True, transform=self.transform\n",
    "            )\n",
    "            # Seeded generator: the split no longer depends on the global RNG state.\n",
    "            self.mnist_train, self.mnist_val = random_split(\n",
    "                mnist_full,\n",
    "                [55000, 5000],\n",
    "                generator=torch.Generator().manual_seed(self.seed),\n",
    "            )\n",
    "\n",
    "        # Assign test dataset for use in dataloader(s)\n",
    "        if stage in (\"test\", \"predict\") or stage is None:\n",
    "            self.mnist_test = datasets.FashionMNIST(\n",
    "                self.data_dir, train=False, transform=self.transform\n",
    "            )\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        return DataLoader(\n",
    "            self.mnist_train,\n",
    "            batch_size=self.batch_size,\n",
    "            shuffle=True,\n",
    "            num_workers=self.num_workers,\n",
    "        )\n",
    "\n",
    "    def val_dataloader(self):\n",
    "        return DataLoader(\n",
    "            self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers\n",
    "        )\n",
    "\n",
    "    def test_dataloader(self):\n",
    "        return DataLoader(\n",
    "            self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers\n",
    "        )\n",
    "\n",
    "    def predict_dataloader(self):\n",
    "        # Not shuffled, so predictions line up with self.mnist_test.targets.\n",
    "        return DataLoader(\n",
    "            self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers\n",
    "        )\n",
    "\n",
    "\n",
    "# Plot loss and metrics after training and validation\n",
    "def plot_training(trainer: Trainer):\n",
    "    \"\"\"Plot per-epoch train/val loss and accuracy from the logger's metrics.csv.\"\"\"\n",
    "    metrics = pd.read_csv(\n",
" f\"{trainer.logger.log_dir}/metrics.csv\",\n", " usecols=[\"epoch\", \"val_loss\", \"val_metric\", \"train_loss\", \"train_metric\"],\n", " )\n", " metrics.set_index(\"epoch\", inplace=True)\n", " val_metrics = metrics[[\"val_loss\", \"val_metric\"]].dropna(axis=0)\n", " train_metrics = metrics[[\"train_loss\", \"train_metric\"]].dropna(axis=0)\n", " metrics = train_metrics.join(val_metrics)\n", " metrics.plot(\n", " subplots=[(\"val_loss\", \"train_loss\"), (\"val_metric\", \"train_metric\")],\n", " figsize=(8, 6),\n", " title=\"Training results\",\n", " ylabel=\"Value\",\n", " )" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "/home/nirogu/.local/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n", " warning_cache.warn(\n", "Running in `fast_dev_run` mode: will run the requested loop using 3 batch(es). 
Logging and checkpointing is suppressed.\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
      "\n",
      "  | Name   | Type               | Params\n",
      "----------------------------------------------\n",
      "0 | metric | MulticlassAccuracy | 0     \n",
      "1 | model  | Sequential         | 55.1 K\n",
      "----------------------------------------------\n",
      "55.1 K    Trainable params\n",
      "0         Non-trainable params\n",
      "55.1 K    Total params\n",
      "0.220     Total estimated model params size (MB)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8343f81a58114187b2f0b23668fbfbfc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c2c455b641b643078a6b0ec4b47176f9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: 0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_steps=3` reached.\n"
     ]
    }
   ],
   "source": [
    "PATH_DATASETS = \"data\"\n",
    "BATCH_SIZE = 256 if torch.cuda.is_available() else 64\n",
    "\n",
    "# Instantiate lightning objects\n",
    "fashion_mnist = MNISTData(PATH_DATASETS, BATCH_SIZE)\n",
    "model = MNISTModel(fashion_mnist.dims, fashion_mnist.num_classes)\n",
    "# model = torch.compile(model, mode=\"reduce-overhead\")\n",
    "\n",
    "# Hyperparameter tuning (batch size / learning rate): Trainer(auto_lr_find=...,\n",
    "# auto_scale_batch_size=...) and trainer.tune() were removed in Lightning 2.0; use\n",
    "# lightning.pytorch.tuner.Tuner(trainer).lr_find(...) / .scale_batch_size(...) instead.\n",
    "\n",
    "# Check everything works (fast run to check nothing crashes).\n",
    "# fast_dev_run takes a number of *batches* (not epochs) and suppresses logging/checkpointing.\n",
    "debugging_batches = 3\n",
    "trainer = Trainer(\n",
    "    fast_dev_run=debugging_batches,\n",
    "    accelerator=\"auto\",\n",
    "    log_every_n_steps=debugging_batches,\n",
    ")  # profiler='simple', callbacks=callbacks.DeviceStatsMonitor()\n",
    "trainer.fit(model, fashion_mnist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using 16bit Automatic Mixed Precision (AMP)\n",
      "Trainer already configured with model summary callbacks: []. Skipping setting a default `ModelSummary` callback.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "Missing logger folder: ./lightning_logs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [ "
┏━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃    Name     Type                Params ┃\n",
       "┡━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│ 0 │ metric  │ MulticlassAccuracy │      0 │\n",
       "│ 1 │ model   │ Sequential         │ 55.1 K │\n",
       "│ 2 │ model.0 │ Flatten            │      0 │\n",
       "│ 3 │ model.1 │ Linear             │ 50.2 K │\n",
       "│ 4 │ model.2 │ Hardswish          │      0 │\n",
       "│ 5 │ model.3 │ Dropout            │      0 │\n",
       "│ 6 │ model.4 │ Linear             │  4.2 K │\n",
       "│ 7 │ model.5 │ Hardswish          │      0 │\n",
       "│ 8 │ model.6 │ Dropout            │      0 │\n",
       "│ 9 │ model.7 │ Linear             │    650 │\n",
       "└───┴─────────┴────────────────────┴────────┘\n",
       "
\n" ], "text/plain": [ "┏━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n", "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n", "┡━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n", "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ metric │ MulticlassAccuracy │ 0 │\n", "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ model │ Sequential │ 55.1 K │\n", "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ model.0 │ Flatten │ 0 │\n", "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ model.1 │ Linear │ 50.2 K │\n", "│\u001b[2m \u001b[0m\u001b[2m4\u001b[0m\u001b[2m \u001b[0m│ model.2 │ Hardswish │ 0 │\n", "│\u001b[2m \u001b[0m\u001b[2m5\u001b[0m\u001b[2m \u001b[0m│ model.3 │ Dropout │ 0 │\n", "│\u001b[2m \u001b[0m\u001b[2m6\u001b[0m\u001b[2m \u001b[0m│ model.4 │ Linear │ 4.2 K │\n", "│\u001b[2m \u001b[0m\u001b[2m7\u001b[0m\u001b[2m \u001b[0m│ model.5 │ Hardswish │ 0 │\n", "│\u001b[2m \u001b[0m\u001b[2m8\u001b[0m\u001b[2m \u001b[0m│ model.6 │ Dropout │ 0 │\n", "│\u001b[2m \u001b[0m\u001b[2m9\u001b[0m\u001b[2m \u001b[0m│ model.7 │ Linear │ 650 │\n", "└───┴─────────┴────────────────────┴────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
Trainable params: 55.1 K                                                                                           \n",
       "Non-trainable params: 0                                                                                            \n",
       "Total params: 55.1 K                                                                                               \n",
       "Total estimated model params size (MB): 0                                                                          \n",
       "
\n" ], "text/plain": [ "\u001b[1mTrainable params\u001b[0m: 55.1 K \n", "\u001b[1mNon-trainable params\u001b[0m: 0 \n", "\u001b[1mTotal params\u001b[0m: 55.1 K \n", "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 0 \n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "1222340643384bdfb9911fc8707337e1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "`Trainer.fit` stopped: `max_epochs=5` reached.\n" ] }, { "data": { "text/html": [ "
\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "DIRECTORY = \"./\"\n", "EPOCHS = 5\n", "PATIENCE_EPOCHS = 3\n", "\n", "progressbar_theme = RichProgressBarTheme(\n", " description=\"black\",\n", " progress_bar=\"bright_blue\",\n", " progress_bar_finished=\"bright_blue\",\n", " progress_bar_pulse=\"bright_blue\",\n", " batch_progress=\"bright_blue\",\n", " time=\"red\",\n", " processing_speed=\"red\",\n", " metrics=\"green\",\n", ")\n", "progressbar_callback = callbacks.RichProgressBar(theme=progressbar_theme)\n", "model_summary_callback = callbacks.RichModelSummary(max_depth=2)\n", "checkpoint_callback = callbacks.ModelCheckpoint(\n", " filename=\"best_weights\", save_top_k=1, monitor=\"val_loss\", mode=\"min\"\n", ")\n", "early_stopping_callback = callbacks.EarlyStopping(\n", " monitor=\"val_loss\", mode=\"min\", patience=PATIENCE_EPOCHS\n", ")\n", "swa_callback = callbacks.StochasticWeightAveraging(1e-2)\n", "\n", "# Create trainer\n", "trainer = Trainer(\n", " default_root_dir=DIRECTORY,\n", " accelerator=\"auto\",\n", " devices=\"auto\",\n", " max_epochs=EPOCHS,\n", " callbacks=[\n", " progressbar_callback,\n", " model_summary_callback,\n", " checkpoint_callback,\n", " early_stopping_callback,\n", " swa_callback,\n", " ],\n", " logger=CSVLogger(save_dir=DIRECTORY),\n", " precision=\"16-mixed\",\n", ")\n", "\n", "# Train model\n", "trainer.fit(model, fashion_mnist)\n", "plot_training(trainer)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ca61c473eabf4f1381c3413db7501265", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃   Runningstage.testing                               ┃\n",
       "┃          metric                  DataLoader 0        ┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│         test_loss             0.3557029664516449     │\n",
       "│        test_metric             0.871999979019165     │\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "
\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1m Runningstage.testing \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m┃\n", "┃\u001b[1m \u001b[0m\u001b[1m metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.3557029664516449 \u001b[0m\u001b[35m \u001b[0m│\n", "│\u001b[36m \u001b[0m\u001b[36m test_metric \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.871999979019165 \u001b[0m\u001b[35m \u001b[0m│\n", "└───────────────────────────┴───────────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Test model (with trainer)\n", "checkpoint = checkpoint_callback.best_model_path\n", "model = MNISTModel.load_from_checkpoint(checkpoint)\n", "test_results = trainer.test(model, fashion_mnist)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "fccb8696c5754735b4ccbd3926a0a7dd", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "
\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "
\n",
       "
\n" ], "text/plain": [ "\n" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Accuracy: 0.871999979019165\n", "Balanced accuracy: 0.871999979019165\n", "Confusion matrix:\n" ] }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Test model (without trainer)\n", "checkpoint = checkpoint_callback.best_model_path\n", "model = MNISTModel.load_from_checkpoint(checkpoint)\n", "\n", "classes = [\n", " \"T-shirt/top\",\n", " \"Trouser\",\n", " \"Pullover\",\n", " \"Dress\",\n", " \"Coat\",\n", " \"Sandal\",\n", " \"Shirt\",\n", " \"Sneaker\",\n", " \"Bag\",\n", " \"Ankle boot\",\n", "]\n", "preds = torch.cat(trainer.predict(model, fashion_mnist))\n", "y = fashion_mnist.mnist_test.targets\n", "\n", "print(\n", " \"Accuracy:\",\n", " torchmetrics.functional.accuracy(\n", " preds, y, task=\"multiclass\", num_classes=fashion_mnist.num_classes\n", " ).item(),\n", ")\n", "print(\n", " \"Balanced accuracy:\",\n", " torchmetrics.functional.accuracy(\n", " preds,\n", " y,\n", " task=\"multiclass\",\n", " average=\"macro\",\n", " num_classes=fashion_mnist.num_classes,\n", " ).item(),\n", ")\n", "print(\"Confusion matrix:\")\n", "cm = sns.heatmap(\n", " torchmetrics.functional.confusion_matrix(\n", " preds,\n", " y,\n", " task=\"multiclass\",\n", " normalize=\"true\",\n", " num_classes=fashion_mnist.num_classes,\n", " )\n", " * 100,\n", " annot=True,\n", " fmt=\".1f\",\n", " cmap=\"YlOrRd\",\n", " cbar=False,\n", " xticklabels=classes,\n", " yticklabels=classes,\n", ")\n", "cm.set(xlabel=\"Predicted\", ylabel=\"True\");" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Production: Lightning" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAK1UlEQVR4nO3czW6V9RrG4acuaukHFVMtWBM/EhFNY5yYoBJnROYaBw6NEw/Ak3DqiZg4YcAZODDGEYkONEgQo0UKxa5auvZk5x5u+vyzqVWva8zNWxYtP9+Bz9xsNpsVAFTVY3/1FwDA8SEKAIQoABCiAECIAgAhCgCEKAAQogBAnDjsL5ybm3uUXwcAj9hh/l9lbwoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECc+Ku/AOB4mUwm7c3BwUF7M5vN2ptRCwsL7c10Om1vXnrppfamqur7778f2j0K3hQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFdS+Ueam5s7ks3IddBnn322vamqeuutt9qbK1eutDc7OzvtzXE3cvF0xPvvvz+0++yzz/7PX8k4bwoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIA4SAe/NfIcbsR77zzztDuwoUL7c3GxkZ78/nnn7c3x936+np7c/ny5fZme3u7vTluvCkAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhIN4/CNNJpP2Zn9/v71544032ptXX321vamqunXrVntz7ty59uaLL75ob7a2ttqbxcXF9qaq6scff2xv1tbW2pvV1dX25qeffmpvjhtvCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgDhIB7H3mOP9f/bZeS43fLycnvzwQcftDfT6bS9qao6efJke3Pq1Kn2Zm5urr0Z+TsaeU5V1ebmZntz/fr19ub27dvtzYkTf/9/Ur0pABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABB//5N+fwMj1yBns9nQs0auVY48a2QzmUzam6qqBw8eDO26Pvnkk/bm559/bm92d3fbm6qqF154ob0Zuax669at9mbk7/bg4KC9qara2dlpb/b29tqb1dXV9mZhYaG9qRq70DvyORyGNwUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGA+FcfxDuqQ3Wjx+1GjB4Z6xo5gHZUh+2qqj788MP25uzZs+3N119/3d7Mz8+3N1VVp0+fbm9+++239mZra6u9eeqpp9qbU6dOtTdV44cVu0aOSy4tLQ0969y5c+3NN998M/Ssh/GmAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABD/6oN4R3WobuSw1simauzo3MjncJTH7T766KP25vz58+3N9evX25uRQ3AjhxirqhYXF9ubGzdutDcjh+pGDjHev3+/vamqOnnyZHtzVMcvR12+fLm9cRAPgEdOFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYA4dgfxRg/BjRg5eDVyWGvkWNj
I5ihtbGy0N++9997Qs0YOwX333XftzcrKSnuzsLDQ3qytrbU3VVV7e3vtzcj3+NLSUnszYvSo4nQ6PZJn7ezstDejP7cXL14c2j0K3hQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUA4tAH8SaTSfs3HzlCddwPwY0cGBvx9NNPD+2ef/759uaVV15pb5555pn2ZuSgW1XV9vZ2e3P69On2ZnV1tb2Zn59vb0aO6FWN/WyMfD+M/Jl+//339ubPP/9sb6rGPoeRQ5t//PFHezPy72RV1d27d9ubzc3NoWc9jDcFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAOLQV1JHLp6OOHPmzNBu5Brk8vLykWwWFxfbmxdffLG9qapaWlpqb0auVd67d6+9GblUWVX1xBNPtDcjn/n+/n57M/J5379/v72pqppOp+3N448/3t7cvHmzvRn5Oxr57Kqqbt++3d6srKy0N08++WR7s7Oz095UVZ09e7a9WVtbG3rWw3hTACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAIhDH8QbcenSpfZmY2Nj6FkjR93W19fbm5GjbgcHB+3NyJ+nquru3bvtzcixsJEDXnNzc+1NVdXCwkJ7M3I0beTvduSzm0wm7U3V2LG1ke+HO3futDcjP0tHaeT7YeTnduQQY9XY4cKRA46H4U0BgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIA59EO/dd99t/+Yff/xxe3Pt2rX2pqrq5s2b7c329nZ7M3LMbG9v70ieM2rkaNrIAa8HDx60N1VVq6ur7c3I8b2RY2YjR9Pm5+fbm6qxI4RnzpxpbzY3N9ubkT/TUX6PjxwTXFpaam92d3fbm6qxr++XX34ZetbDeFMAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQAiEMfxPvqq6/av/mbb77Z3rz22mvtTVXVxYsXh3Zd+/v77c3Iwbmtra32ZnR3586d9mbkIN7IkbqqqrW1tfbm/Pnz7c3IAbSRY32z2ay9qap6/fXX25tvv/22vfnhhx/am0uXLrU3CwsL7U3V+OfXNfKzfuPGjaFnjRznXFlZGXrWw3hTACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAIi52SGvS40eMzsqI8ehLly40N68/PLL7c3bb7/d3qyvr7c3VWMH2paXl9ubke+H0UNmBwcH7c3IYcBr1661N1evXm1vrly50t5UVe3u7g7tjsKXX37Z3jz33HNDz/r111/bm5GjlCObkSN6VVXT6bS9+fTTT9ube/fuPfTXeFMAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIP4xV1IB+N8O88+9NwUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAIgTh/2Fs9nsUX4dABwD3hQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFACI/wC8gLF1VGuA8QAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Predicted: \"Ankle boot\", Actual: \"Ankle boot\"\n" ] } ], "source": [ "checkpoint = checkpoint_callback.best_model_path\n", "model = MNISTModel.load_from_checkpoint(checkpoint)\n", "\n", "classes = [\n", " \"T-shirt/top\",\n", " \"Trouser\",\n", " \"Pullover\",\n", " \"Dress\",\n", " \"Coat\",\n", " \"Sandal\",\n", " \"Shirt\",\n", " \"Sneaker\",\n", " \"Bag\",\n", " \"Ankle boot\",\n", "]\n", "test_data = datasets.FashionMNIST(\n", " root=\"data\", train=False, download=True, transform=transforms.ToTensor()\n", ")\n", "x, y = test_data[0][0], test_data[0][1]\n", "plt.imshow(x.numpy()[0], cmap=\"gray\")\n", "plt.axis(\"off\")\n", "plt.show()\n", "\n", "model.eval()\n", "with torch.inference_mode():\n", " pred = model(x.to(\"cuda\"))\n", " predicted, actual = classes[pred[0].argmax(0)], classes[y]\n", " print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Production: TorchScript\n", "## Save model" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "checkpoint = checkpoint_callback.best_model_path\n", "model = MNISTModel.load_from_checkpoint(checkpoint)\n", "script = model.to_torchscript()\n", "torch.jit.save(script, \"model.pt\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Load model" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAK1UlEQVR4nO3czW6V9RrG4acuaukHFVMtWBM/EhFNY5yYoBJnROYaBw6NEw/Ak3DqiZg4YcAZODDGEYkONEgQo0UKxa5auvZk5x5u+vyzqVWva8zNWxYtP9+Bz9xsNpsVAFTVY3/1FwDA8SEKAIQoABCiAECIAgAhCgCEKAAQogBAnDjsL5ybm3uUXwcAj9hh/l9lbwoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECc+Ku/AOB4mUwm7c3BwUF7M5vN2ptRCwsL7c10Om1vXnrppfamqur7778f2j0K3hQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFdS+Ueam5s7ks3IddBnn322vamqeuutt9qbK1eutDc7OzvtzXE3cvF0xPvvvz+0++yzz/7PX8k4bwoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIA4SAe/NfIcbsR77zzztDuwoUL7c3GxkZ78/nnn7c3x936+np7c/ny5fZme3u7vTluvCkAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhIN4/CNNJpP2Zn9/v71544032ptXX321vamqunXrVntz7ty59uaLL75ob7a2ttqbxcXF9qaq6scff2xv1tbW2pvV1dX25qeffmpvjhtvCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgDhIB7H3mOP9f/bZeS43fLycnvzwQcftDfT6bS9qao6efJke3Pq1Kn2Zm5urr0Z+TsaeU5V1ebmZntz/fr19ub27dvtzYkTf/9/Ur0pABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABB//5N+fwMj1yBns9nQs0auVY48a2QzmUzam6qqBw8eDO26Pvnkk/bm559/bm92d3fbm6qqF154ob0Zuax669at9mbk7/bg4KC9qara2dlpb/b29tqb1dXV9mZhYaG9qRq70DvyORyGNwUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGA+FcfxDuqQ3Wjx+1GjB4Z6xo5gHZUh+2qqj788MP25uzZs+3N119/3d7Mz8+3N1VVp0+fbm9+++239mZra6u9eeqpp9qbU6dOtTdV44cVu0aOSy4tLQ0969y5c+3NN998M/Ssh/GmAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABD/6oN4R3WobuSw1simauzo3MjncJTH7T766KP25vz58+3N9evX25uRQ3AjhxirqhYXF9ubGzdutDcjh+pGDjHev3+/vamqOnnyZHtzVMcvR12+fLm9cRAPgEdOFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYA4dgfxRg/BjRg5eDVyWGvkWNj
I5ihtbGy0N++9997Qs0YOwX333XftzcrKSnuzsLDQ3qytrbU3VVV7e3vtzcj3+NLSUnszYvSo4nQ6PZJn7ezstDejP7cXL14c2j0K3hQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUA4tAH8SaTSfs3HzlCddwPwY0cGBvx9NNPD+2ef/759uaVV15pb5555pn2ZuSgW1XV9vZ2e3P69On2ZnV1tb2Zn59vb0aO6FWN/WyMfD+M/Jl+//339ubPP/9sb6rGPoeRQ5t//PFHezPy72RV1d27d9ubzc3NoWc9jDcFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAOLQV1JHLp6OOHPmzNBu5Brk8vLykWwWFxfbmxdffLG9qapaWlpqb0auVd67d6+9GblUWVX1xBNPtDcjn/n+/n57M/J5379/v72pqppOp+3N448/3t7cvHmzvRn5Oxr57Kqqbt++3d6srKy0N08++WR7s7Oz095UVZ09e7a9WVtbG3rWw3hTACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAIhDH8QbcenSpfZmY2Nj6FkjR93W19fbm5GjbgcHB+3NyJ+nquru3bvtzcixsJEDXnNzc+1NVdXCwkJ7M3I0beTvduSzm0wm7U3V2LG1ke+HO3futDcjP0tHaeT7YeTnduQQY9XY4cKRA46H4U0BgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIA59EO/dd99t/+Yff/xxe3Pt2rX2pqrq5s2b7c329nZ7M3LMbG9v70ieM2rkaNrIAa8HDx60N1VVq6ur7c3I8b2RY2YjR9Pm5+fbm6qxI4RnzpxpbzY3N9ubkT/TUX6PjxwTXFpaam92d3fbm6qxr++XX34ZetbDeFMAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQAiEMfxPvqq6/av/mbb77Z3rz22mvtTVXVxYsXh3Zd+/v77c3Iwbmtra32ZnR3586d9mbkIN7IkbqqqrW1tfbm/Pnz7c3IAbSRY32z2ay9qap6/fXX25tvv/22vfnhhx/am0uXLrU3CwsL7U3V+OfXNfKzfuPGjaFnjRznXFlZGXrWw3hTACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAIi52SGvS40eMzsqI8ehLly40N68/PLL7c3bb7/d3qyvr7c3VWMH2paXl9ubke+H0UNmBwcH7c3IYcBr1661N1evXm1vrly50t5UVe3u7g7tjsKXX37Z3jz33HNDz/r111/bm5GjlCObkSN6VVXT6bS9+fTTT9ube/fuPfTXeFMAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIP4xV1IB+N8O88+9NwUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAIgTh/2Fs9nsUX4dABwD3hQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFACI/wC8gLF1VGuA8QAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Predicted: \"Ankle boot\", Actual: \"Ankle boot\"\n" ] } ], "source": [ "model = torch.jit.load(\"model.pt\")\n", "\n", "\n", "@torch.inference_mode()\n", "def predict(x: torch.Tensor) -> torch.Tensor:\n", " return model(x)\n", "\n", "\n", "classes = [\n", " \"T-shirt/top\",\n", " \"Trouser\",\n", " \"Pullover\",\n", " \"Dress\",\n", " \"Coat\",\n", " \"Sandal\",\n", " \"Shirt\",\n", " \"Sneaker\",\n", " \"Bag\",\n", " \"Ankle boot\",\n", "]\n", "test_data = datasets.FashionMNIST(\n", " root=\"data\", train=False, download=True, transform=transforms.ToTensor()\n", ")\n", "x, y = test_data[0][0], test_data[0][1]\n", "plt.imshow(x.numpy()[0], cmap=\"gray\")\n", "plt.axis(\"off\")\n", "plt.show()\n", "\n", "pred = predict(x.to(\"cuda\"))\n", "predicted, actual = classes[pred[0].argmax(0)], classes[y]\n", "print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Production: ONNX\n", "## Save model" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "============= Diagnostic Run torch.onnx.export version 2.0.1+cu117 =============\n", "verbose: False, log level: Level.ERROR\n", "======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================\n", "\n" ] } ], "source": [ "checkpoint = checkpoint_callback.best_model_path\n", "model = MNISTModel.load_from_checkpoint(checkpoint)\n", "input_sample = datasets.FashionMNIST(\n", " root=\"data\", train=False, download=True, transform=transforms.ToTensor()\n", ")\n", "input_sample = input_sample[0][0]\n", "model.to_onnx(\"model.onnx\", input_sample, export_params=True)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Load model" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": 
[ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAGFCAYAAAASI+9IAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAK1UlEQVR4nO3czW6V9RrG4acuaukHFVMtWBM/EhFNY5yYoBJnROYaBw6NEw/Ak3DqiZg4YcAZODDGEYkONEgQo0UKxa5auvZk5x5u+vyzqVWva8zNWxYtP9+Bz9xsNpsVAFTVY3/1FwDA8SEKAIQoABCiAECIAgAhCgCEKAAQogBAnDjsL5ybm3uUXwcAj9hh/l9lbwoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECc+Ku/AOB4mUwm7c3BwUF7M5vN2ptRCwsL7c10Om1vXnrppfamqur7778f2j0K3hQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFdS+Ueam5s7ks3IddBnn322vamqeuutt9qbK1eutDc7OzvtzXE3cvF0xPvvvz+0++yzz/7PX8k4bwoAhCgAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIA4SAe/NfIcbsR77zzztDuwoUL7c3GxkZ78/nnn7c3x936+np7c/ny5fZme3u7vTluvCkAEKIAQIgCACEKAIQoABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhIN4/CNNJpP2Zn9/v71544032ptXX321vamqunXrVntz7ty59uaLL75ob7a2ttqbxcXF9qaq6scff2xv1tbW2pvV1dX25qeffmpvjhtvCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABCiAECIAgDhIB7H3mOP9f/bZeS43fLycnvzwQcftDfT6bS9qao6efJke3Pq1Kn2Zm5urr0Z+TsaeU5V1ebmZntz/fr19ub27dvtzYkTf/9/Ur0pABCiAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABB//5N+fwMj1yBns9nQs0auVY48a2QzmUzam6qqBw8eDO26Pvnkk/bm559/bm92d3fbm6qqF154ob0Zuax669at9mbk7/bg4KC9qara2dlpb/b29tqb1dXV9mZhYaG9qRq70DvyORyGNwUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGA+FcfxDuqQ3Wjx+1GjB4Z6xo5gHZUh+2qqj788MP25uzZs+3N119/3d7Mz8+3N1VVp0+fbm9+++239mZra6u9eeqpp9qbU6dOtTdV44cVu0aOSy4tLQ0969y5c+3NN998M/Ssh/GmAECIAgAhCgCEKAAQogBAiAIAIQoAhCgAEKIAQIgCACEKAIQoABD/6oN4R3WobuSw1simauzo3MjncJTH7T766KP25vz58+3N9evX25uRQ3AjhxirqhYXF9ubGzdutDcjh+pGDjHev3+/vamqOnnyZHtzVMcvR12+fLm9cRAPgEdOFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBE
AYA4dgfxRg/BjRg5eDVyWGvkWNjI5ihtbGy0N++9997Qs0YOwX333XftzcrKSnuzsLDQ3qytrbU3VVV7e3vtzcj3+NLSUnszYvSo4nQ6PZJn7ezstDejP7cXL14c2j0K3hQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUA4tAH8SaTSfs3HzlCddwPwY0cGBvx9NNPD+2ef/759uaVV15pb5555pn2ZuSgW1XV9vZ2e3P69On2ZnV1tb2Zn59vb0aO6FWN/WyMfD+M/Jl+//339ubPP/9sb6rGPoeRQ5t//PFHezPy72RV1d27d9ubzc3NoWc9jDcFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAOLQV1JHLp6OOHPmzNBu5Brk8vLykWwWFxfbmxdffLG9qapaWlpqb0auVd67d6+9GblUWVX1xBNPtDcjn/n+/n57M/J5379/v72pqppOp+3N448/3t7cvHmzvRn5Oxr57Kqqbt++3d6srKy0N08++WR7s7Oz095UVZ09e7a9WVtbG3rWw3hTACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAIhDH8QbcenSpfZmY2Nj6FkjR93W19fbm5GjbgcHB+3NyJ+nquru3bvtzcixsJEDXnNzc+1NVdXCwkJ7M3I0beTvduSzm0wm7U3V2LG1ke+HO3futDcjP0tHaeT7YeTnduQQY9XY4cKRA46H4U0BgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIA59EO/dd99t/+Yff/xxe3Pt2rX2pqrq5s2b7c329nZ7M3LMbG9v70ieM2rkaNrIAa8HDx60N1VVq6ur7c3I8b2RY2YjR9Pm5+fbm6qxI4RnzpxpbzY3N9ubkT/TUX6PjxwTXFpaam92d3fbm6qxr++XX34ZetbDeFMAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQAiEMfxPvqq6/av/mbb77Z3rz22mvtTVXVxYsXh3Zd+/v77c3Iwbmtra32ZnR3586d9mbkIN7IkbqqqrW1tfbm/Pnz7c3IAbSRY32z2ay9qap6/fXX25tvv/22vfnhhx/am0uXLrU3CwsL7U3V+OfXNfKzfuPGjaFnjRznXFlZGXrWw3hTACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAIi52SGvS40eMzsqI8ehLly40N68/PLL7c3bb7/d3qyvr7c3VWMH2paXl9ubke+H0UNmBwcH7c3IYcBr1661N1evXm1vrly50t5UVe3u7g7tjsKXX37Z3jz33HNDz/r111/bm5GjlCObkSN6VVXT6bS9+fTTT9ube/fuPfTXeFMAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIP4xV1IB+N8O88+9NwUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFAAIUQAgRAGAEAUAQhQACFEAIEQBgBAFAEIUAIgTh/2Fs9nsUX4dABwD3hQACFEAIEQBgBAFAEIUAAhRACBEAYAQBQBCFACI/wC8gLF1VGuA8QAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Predicted: \"Ankle boot\", Actual: \"Ankle boot\"\n" ] } ], "source": [ "ort_session = onnxruntime.InferenceSession(\"model.onnx\")\n", "\n", "classes = [\n", " \"T-shirt/top\",\n", " \"Trouser\",\n", " \"Pullover\",\n", " \"Dress\",\n", " \"Coat\",\n", " \"Sandal\",\n", " \"Shirt\",\n", " \"Sneaker\",\n", " \"Bag\",\n", " \"Ankle boot\",\n", "]\n", "test_data = datasets.FashionMNIST(\n", " root=\"data\", train=False, download=True, transform=transforms.ToTensor()\n", ")\n", "x, y = test_data[0][0], test_data[0][1]\n", "plt.imshow(x.numpy()[0], cmap=\"gray\")\n", "plt.axis(\"off\")\n", "plt.show()\n", "\n", "input_name = ort_session.get_inputs()[0].name\n", "ort_inputs = {input_name: x.numpy()}\n", "ort_outs = ort_session.run(None, ort_inputs)[0]\n", "predicted, actual = classes[ort_outs[0].argmax(0)], classes[y]\n", "print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.3" }, "vscode": { "interpreter": { "hash": "767d51c1340bd893661ea55ea3124f6de3c7a262a8b4abca0554b478b1e2ff90" } } }, "nbformat": 4, "nbformat_minor": 4 }