{"cells": [{"cell_type": "markdown", "id": "b36d598d", "metadata": {"papermill": {"duration": 0.00325, "end_time": "2025-04-03T20:53:24.581994", "exception": false, "start_time": "2025-04-03T20:53:24.578744", "status": "completed"}, "tags": []}, "source": ["\n", "# Introduction to PyTorch Lightning\n", "\n", "* **Author:** Lightning.ai\n", "* **License:** CC BY-SA\n", "* **Generated:** 2025-04-03T20:53:18.167877\n", "\n", "In this notebook, we'll go over the basics of lightning by preparing models to train on the [MNIST Handwritten Digits dataset](https://en.wikipedia.org/wiki/MNIST_database).\n", "\n", "---\n", "Open in [{height=\"20px\" width=\"117px\"}](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/mnist-hello-world.ipynb)\n", "\n", "Give us a \u2b50 [on Github](https://www.github.com/Lightning-AI/lightning/)\n", "| Check out [the documentation](https://lightning.ai/docs/)\n", "| Join us [on Discord](https://discord.com/invite/tfXFetEZxv)"]}, {"cell_type": "markdown", "id": "27a8ad74", "metadata": {"papermill": {"duration": 0.00234, "end_time": "2025-04-03T20:53:24.586994", "exception": false, "start_time": "2025-04-03T20:53:24.584654", "status": "completed"}, "tags": []}, "source": ["## Setup\n", "This notebook requires some packages besides pytorch-lightning."]}, {"cell_type": "code", "execution_count": 1, "id": "0def7e9d", "metadata": {"colab": {}, "colab_type": "code", "execution": {"iopub.execute_input": "2025-04-03T20:53:24.592596Z", "iopub.status.busy": "2025-04-03T20:53:24.592415Z", "iopub.status.idle": "2025-04-03T20:53:25.767242Z", "shell.execute_reply": "2025-04-03T20:53:25.765748Z"}, "id": "LfrJLKPFyhsK", "lines_to_next_cell": 0, "papermill": {"duration": 1.179969, "end_time": "2025-04-03T20:53:25.769318", "exception": false, "start_time": "2025-04-03T20:53:24.589349", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\r\n", "\u001b[0m"]}, {"name": "stdout", "output_type": "stream", "text": ["\r\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.0.1\u001b[0m\r\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\r\n"]}], "source": ["! 
pip install --quiet \"torch>=1.8.1, <2.7\" \"matplotlib\" \"torchmetrics>=1.0, <1.8\" \"pandas\" \"numpy <3.0\" \"pytorch-lightning >=2.0,<2.6\" \"torchvision\" \"seaborn\""]}, {"cell_type": "code", "execution_count": 2, "id": "9a148e57", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T20:53:25.776575Z", "iopub.status.busy": "2025-04-03T20:53:25.776033Z", "iopub.status.idle": "2025-04-03T20:53:29.333206Z", "shell.execute_reply": "2025-04-03T20:53:29.331870Z"}, "papermill": {"duration": 3.563765, "end_time": "2025-04-03T20:53:29.335766", "exception": false, "start_time": "2025-04-03T20:53:25.772001", "status": "completed"}, "tags": []}, "outputs": [], "source": ["\n", "# ------------------- Preliminaries ------------------- #\n", "import os\n", "from dataclasses import dataclass\n", "from typing import Tuple\n", "\n", "import pandas as pd\n", "import pytorch_lightning as pl\n", "import seaborn as sn\n", "import torch\n", "from IPython.display import display\n", "from pytorch_lightning.loggers import CSVLogger\n", "from torch import nn\n", "from torch.nn import functional as F\n", "from torch.utils.data import DataLoader, random_split\n", "from torchmetrics import Accuracy\n", "from torchvision import transforms\n", "from torchvision.datasets import MNIST\n", "\n", "# ------------------- Configuration ------------------- #\n", "\n", "\n", "@dataclass\n", "class Config:\n", "    \"\"\"Configuration options for the Lightning MNIST example.\n", "\n", "    Args:\n", "        data_dir : The path to the directory where the MNIST dataset is stored. Defaults to the value of\n", "            the 'PATH_DATASETS' environment variable or '.' if not set.\n", "\n", "        save_dir : The path to the directory where the training logs will be saved. Defaults to 'logs/'.\n", "\n", "        batch_size : The batch size to use during training. Defaults to 256 if a GPU is available,\n", "            or 64 otherwise.\n", "\n", "        max_epochs : The maximum number of epochs to train the model for. Defaults to 3.\n", "\n", "        accelerator : The accelerator to use for training. Can be one of \"cpu\", \"gpu\", \"tpu\", \"ipu\", \"auto\".\n", "\n", "        devices : The number of devices to use for training. Defaults to 1.\n", "\n", "    Examples:\n", "        This dataclass can be used to specify the configuration options for training a PyTorch Lightning model on the\n", "        MNIST dataset. A new instance of this dataclass can be created as follows:\n", "\n", "        >>> config = Config()\n", "\n", "        The default values for each argument are shown in the documentation above. If desired, any of these values can be\n", "        overridden when creating a new instance of the dataclass:\n", "\n", "        >>> config = Config(batch_size=128, max_epochs=5)\n", "\n", "        
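Because Config is a regular dataclass, an existing instance can also be copied with selected fields changed,\n", "        using the standard library's dataclasses.replace (a sketch; the new value is illustrative only):\n", "\n", "        >>> from dataclasses import replace\n", "        >>> config = replace(config, batch_size=32)\n", "\n", "    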
\"\"\"\n", "\n", "    data_dir: str = os.environ.get(\"PATH_DATASETS\", \".\")\n", "    save_dir: str = \"logs/\"\n", "    batch_size: int = 256 if torch.cuda.is_available() else 64\n", "    max_epochs: int = 3\n", "    accelerator: str = \"auto\"\n", "    devices: int = 1\n", "\n", "\n", "config = Config()"]}, {"cell_type": "markdown", "id": "b7fa54a6", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.005546, "end_time": "2025-04-03T20:53:29.347310", "exception": false, "start_time": "2025-04-03T20:53:29.341764", "status": "completed"}, "tags": []}, "source": ["## Simplest example\n", "\n", "Here's the simplest, most minimal example: just a training loop (no validation, no testing).\n", "\n", "**Keep in Mind** - A `LightningModule` *is* a PyTorch `nn.Module` - it just has a few more helpful features."]}, {"cell_type": "code", "execution_count": 3, "id": "8e611193", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T20:53:29.360014Z", "iopub.status.busy": "2025-04-03T20:53:29.359609Z", "iopub.status.idle": "2025-04-03T20:53:29.371985Z", "shell.execute_reply": "2025-04-03T20:53:29.370820Z"}, "papermill": {"duration": 0.021379, "end_time": "2025-04-03T20:53:29.374211", "exception": false, "start_time": "2025-04-03T20:53:29.352832", "status": "completed"}, "tags": []}, "outputs": [], "source": ["\n", "\n", "class MNISTModel(pl.LightningModule):\n", "    \"\"\"A PyTorch Lightning module for classifying images in the MNIST dataset.\n", "\n", "    Attributes:\n", "        l1 : A linear layer that maps input features to output features.\n", "\n", "    Methods:\n", "        forward(x):\n", "            Performs a forward pass through the model.\n", "\n", "        training_step(batch, batch_nb):\n", "            Defines a single training step for the model.\n", "\n", "        configure_optimizers():\n", "            Configures the optimizer to use during training.\n", "\n", "    Examples:\n", "        The MNISTModel class can be used to create and train a PyTorch Lightning model for classifying images in the MNIST\n", "        dataset. 
To create a new instance of the model, simply instantiate the class:\n", "\n", "        >>> model = MNISTModel()\n", "\n", "        The model can then be trained by passing it, along with a training dataloader, to a PyTorch Lightning trainer:\n", "\n", "        >>> trainer = pl.Trainer()\n", "        >>> trainer.fit(model, train_loader)  # train_loader: a DataLoader of MNIST images\n", "\n", "    \"\"\"\n", "\n", "    def __init__(self):\n", "        \"\"\"Initializes a new instance of the MNISTModel class.\"\"\"\n", "        super().__init__()\n", "        self.l1 = torch.nn.Linear(28 * 28, 10)\n", "\n", "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n", "        \"\"\"Performs a forward pass through the model.\n", "\n", "        Args:\n", "            x : The input tensor to pass through the model.\n", "\n", "        Returns:\n", "            activated : The output tensor produced by the model.\n", "\n", "        Examples:\n", "            >>> model = MNISTModel()\n", "            >>> x = torch.randn(1, 1, 28, 28)\n", "            >>> output = model(x)\n", "\n", "        \"\"\"\n", "        flattened = x.view(x.size(0), -1)\n", "        hidden = self.l1(flattened)\n", "        activated = torch.relu(hidden)\n", "\n", "        return activated\n", "\n", "    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_nb: int) -> torch.Tensor:\n", "        \"\"\"Defines a single training step for the model.\n", "\n", "        Args:\n", "            batch: A tuple containing the input and target tensors for the batch.\n", "            batch_nb: The batch number.\n", "\n", "        Returns:\n", "            torch.Tensor: The loss value for the current batch.\n", "\n", "        Examples:\n", "            >>> model = MNISTModel()\n", "            >>> x = torch.randn(1, 1, 28, 28)\n", "            >>> y = torch.tensor([1])\n", "            >>> loss = model.training_step((x, y), 0)\n", "\n", "        \"\"\"\n", "        x, y = batch\n", "        loss = F.cross_entropy(self(x), y)\n", "        return loss\n", "\n", "    def configure_optimizers(self) -> torch.optim.Optimizer:\n", "        \"\"\"Configures the optimizer to use during training.\n", "\n", "        Returns:\n", "            torch.optim.Optimizer: The optimizer to use during training.\n", "\n", "        Examples:\n", "            >>> model = MNISTModel()\n", "            >>> optimizer = model.configure_optimizers()\n", "\n", "        \"\"\"\n", "        return torch.optim.Adam(self.parameters(), lr=0.02)"]}, {"cell_type": "markdown", "id": "70f96910", "metadata": {"papermill": {"duration": 0.005677, "end_time": "2025-04-03T20:53:29.385581", "exception": false, "start_time": "2025-04-03T20:53:29.379904", "status": "completed"}, "tags": []}, "source": ["By using the `Trainer` you automatically get:\n", "1. TensorBoard logging\n", "2. Model checkpointing\n", "3. Training and validation loop\n", "4. 
early-stopping"]}, {"cell_type": "code", "execution_count": 4, "id": "36648472", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T20:53:29.399010Z", "iopub.status.busy": "2025-04-03T20:53:29.397956Z", "iopub.status.idle": "2025-04-03T20:53:48.559987Z", "shell.execute_reply": "2025-04-03T20:53:48.558794Z"}, "papermill": {"duration": 19.171314, "end_time": "2025-04-03T20:53:48.562520", "exception": false, "start_time": "2025-04-03T20:53:29.391206", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", "Failed to download (trying next):\n", "HTTP Error 404: Not Found\n", "\n", "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz\n", "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /__w/11/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", " 0%| | 0/9912422 [00:00, ?it/s]"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", " 7%|\u258b | 720896/9912422 [00:00<00:01, 6751932.08it/s]"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", " 47%|\u2588\u2588\u2588\u2588\u258b | 4653056/9912422 [00:00<00:00, 25281479.35it/s]"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", " 81%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u258f | 8060928/9912422 [00:00<00:00, 29215494.01it/s]"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 9912422/9912422 [00:00<00:00, 27468042.31it/s]"]}, {"name": "stderr", "output_type": "stream", "text": ["\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Extracting /__w/11/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz to /__w/11/s/.datasets/MNIST/raw\n"]}, {"name": "stdout", "output_type": "stream", "text": ["\n", "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", "Failed to download (trying next):\n", "HTTP Error 404: Not Found\n", "\n", "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz\n", "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to /__w/11/s/.datasets/MNIST/raw/train-labels-idx1-ubyte.gz\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", " 0%| | 0/28881 [00:00, ?it/s]"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 28881/28881 [00:00<00:00, 1461720.41it/s]"]}, {"name": "stderr", "output_type": "stream", "text": ["\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Extracting /__w/11/s/.datasets/MNIST/raw/train-labels-idx1-ubyte.gz to /__w/11/s/.datasets/MNIST/raw\n", "\n", "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", "Failed to download (trying next):\n", "HTTP Error 404: Not Found\n", "\n", "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz\n", "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to /__w/11/s/.datasets/MNIST/raw/t10k-images-idx3-ubyte.gz\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", " 0%| | 0/1648877 [00:00, ?it/s]"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", " 48%|\u2588\u2588\u2588\u2588\u258a | 786432/1648877 [00:00<00:00, 7213728.10it/s]"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", 
"100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 1648877/1648877 [00:00<00:00, 10605731.41it/s]"]}, {"name": "stderr", "output_type": "stream", "text": ["\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Extracting /__w/11/s/.datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to /__w/11/s/.datasets/MNIST/raw\n", "\n", "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", "Failed to download (trying next):\n", "HTTP Error 404: Not Found\n", "\n", "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to /__w/11/s/.datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", " 0%| | 0/4542 [00:00, ?it/s]"]}, {"name": "stderr", "output_type": "stream", "text": ["\r", "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 4542/4542 [00:00<00:00, 4702673.11it/s]"]}, {"name": "stdout", "output_type": "stream", "text": ["Extracting /__w/11/s/.datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to /__w/11/s/.datasets/MNIST/raw\n", "\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\n", "GPU available: True (cuda), used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n"]}, {"name": "stderr", "output_type": "stream", "text": ["You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n"]}, {"name": "stderr", "output_type": "stream", "text": ["LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\n", " | Name | Type | Params | Mode \n", "----------------------------------------\n", "0 | l1 | Linear | 7.9 K | train\n", "----------------------------------------\n", "7.9 K Trainable params\n", "0 Non-trainable params\n", "7.9 K Total params\n", "0.031 Total estimated model params size (MB)\n", "1 Modules in train mode\n", "0 Modules in eval mode\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. 
Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "b78a6bd2f4fd46a283c97b7770ca1691", "version_major": 2, "version_minor": 0}, "text/plain": ["Training: | | 0/? [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["`Trainer.fit` stopped: `max_epochs=3` reached.\n"]}], "source": ["# Init our model\n", "mnist_model = MNISTModel()\n", "\n", "# Init DataLoader from MNIST Dataset\n", "train_ds = MNIST(config.data_dir, train=True, download=True, transform=transforms.ToTensor())\n", "\n", "# Create a dataloader\n", "train_loader = DataLoader(train_ds, batch_size=config.batch_size)\n", "\n", "# Initialize a trainer\n", "trainer = pl.Trainer(\n", "    accelerator=config.accelerator,\n", "    devices=config.devices,\n", "    max_epochs=config.max_epochs,\n", ")\n", "\n", "# Train the model \u26a1\n", "trainer.fit(mnist_model, train_loader)"]}, {"cell_type": "markdown", "id": "194b7969", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.008596, "end_time": "2025-04-03T20:53:48.584480", "exception": false, "start_time": "2025-04-03T20:53:48.575884", "status": "completed"}, "tags": []}, "source": ["## A more complete MNIST Lightning Module Example\n", "\n", "That wasn't so hard, was it?\n", "\n", "Now that we've got our feet wet, let's dive in a bit deeper and write a more complete `LightningModule` for MNIST...\n", "\n", "This time, we'll bake all the dataset-specific pieces directly into the `LightningModule`.\n", "This way, we can avoid writing extra code at the beginning of our script every time we want to run it.\n", "\n", "---\n", "\n", "### Note what the following built-in functions are doing:\n", "\n", "1. [prepare_data()](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#prepare-data) \ud83d\udcbe\n", "   - This is where we can download the dataset. We point to our desired dataset and ask torchvision's `MNIST` dataset class to download if the dataset isn't found there.\n", "   - **Note we do not make any state assignments in this function** (i.e. `self.something = ...`)\n", "\n", "2. [setup(stage)](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#setup) \u2699\ufe0f\n", "   - Loads in data from file and prepares PyTorch tensor datasets for each split (train, val, test).\n", "   - Setup expects a 'stage' arg which is used to separate logic for 'fit' and 'test'.\n", "   - If you don't mind loading all your datasets at once, you can set up a condition to allow for both 'fit' related setup and 'test' related setup to run whenever `None` is passed to `stage` (or ignore it altogether and exclude any conditionals).\n", "   - **Note this runs across all GPUs and it *is* safe to make state assignments here**\n", "\n", "3. 
[x_dataloader()](https://lightning.ai/docs/pytorch/stable/api/pytorch_lightning.core.hooks.DataHooks.html#pytorch_lightning.core.hooks.DataHooks.train_dataloader) \u267b\ufe0f\n", " - `train_dataloader()`, `val_dataloader()`, and `test_dataloader()` all return PyTorch `DataLoader` instances that are created by wrapping their respective datasets that we prepared in `setup()`"]}, {"cell_type": "code", "execution_count": 5, "id": "cb60ef5e", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T20:53:48.604190Z", "iopub.status.busy": "2025-04-03T20:53:48.603956Z", "iopub.status.idle": "2025-04-03T20:53:48.622508Z", "shell.execute_reply": "2025-04-03T20:53:48.621491Z"}, "papermill": {"duration": 0.030936, "end_time": "2025-04-03T20:53:48.624004", "exception": false, "start_time": "2025-04-03T20:53:48.593068", "status": "completed"}, "tags": []}, "outputs": [], "source": ["\n", "\n", "class LitMNIST(pl.LightningModule):\n", " \"\"\"PyTorch Lightning module for training a multi-layer perceptron (MLP) on the MNIST dataset.\n", "\n", " Attributes:\n", " data_dir : The path to the directory where the MNIST data will be downloaded.\n", "\n", " hidden_size : The number of units in the hidden layer of the MLP.\n", "\n", " learning_rate : The learning rate to use for training the MLP.\n", "\n", " Methods:\n", " forward(x):\n", " Performs a forward pass through the MLP.\n", "\n", " training_step(batch, batch_idx):\n", " Defines a single training step for the MLP.\n", "\n", " validation_step(batch, batch_idx):\n", " Defines a single validation step for the MLP.\n", "\n", " test_step(batch, batch_idx):\n", " Defines a single testing step for the MLP.\n", "\n", " configure_optimizers():\n", " Configures the optimizer to use for training the MLP.\n", "\n", " prepare_data():\n", " Downloads the MNIST dataset.\n", "\n", " setup(stage=None):\n", " Splits the MNIST dataset into train, validation, and test sets.\n", "\n", " train_dataloader():\n", " Returns a DataLoader for the training set.\n", "\n", " val_dataloader():\n", " Returns a DataLoader for the validation set.\n", "\n", " test_dataloader():\n", " Returns a DataLoader for the test set.\n", "\n", " \"\"\"\n", "\n", " def __init__(self, data_dir: str = config.data_dir, hidden_size: int = 64, learning_rate: float = 2e-4):\n", " \"\"\"Initializes a new instance of the LitMNIST class.\n", "\n", " Args:\n", " data_dir : The path to the directory where the MNIST data will be downloaded. 
Defaults to config.data_dir.\n", "\n", "            hidden_size : The number of units in the hidden layer of the MLP (default is 64).\n", "\n", "            learning_rate : The learning rate to use for training the MLP (default is 2e-4).\n", "\n", "        \"\"\"\n", "        super().__init__()\n", "\n", "        # Set our init args as class attributes\n", "        self.data_dir = data_dir\n", "        self.hidden_size = hidden_size\n", "        self.learning_rate = learning_rate\n", "\n", "        # Hardcode some dataset-specific attributes\n", "        self.num_classes = 10\n", "        self.dims = (1, 28, 28)\n", "        channels, width, height = self.dims\n", "\n", "        self.transform = transforms.Compose(\n", "            [\n", "                transforms.ToTensor(),\n", "                transforms.Normalize((0.1307,), (0.3081,)),\n", "            ]\n", "        )\n", "\n", "        # Define PyTorch model\n", "        self.model = nn.Sequential(\n", "            nn.Flatten(),\n", "            nn.Linear(channels * width * height, hidden_size),\n", "            nn.ReLU(),\n", "            nn.Dropout(0.1),\n", "            nn.Linear(hidden_size, hidden_size),\n", "            nn.ReLU(),\n", "            nn.Dropout(0.1),\n", "            nn.Linear(hidden_size, self.num_classes),\n", "        )\n", "\n", "        self.val_accuracy = Accuracy(task=\"multiclass\", num_classes=10)\n", "        self.test_accuracy = Accuracy(task=\"multiclass\", num_classes=10)\n", "\n", "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n", "        \"\"\"Performs a forward pass through the MLP.\n", "\n", "        Args:\n", "            x : The input data.\n", "\n", "        Returns:\n", "            torch.Tensor: The output of the MLP.\n", "\n", "        \"\"\"\n", "        x = self.model(x)\n", "        return F.log_softmax(x, dim=1)\n", "\n", "    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor:\n", "        \"\"\"Defines a single training step for the MLP.\n", "\n", "        Args:\n", "            batch: A tuple containing the input data and target labels.\n", "\n", "            batch_idx: The index of the current batch.\n", "\n", "        Returns:\n", "            torch.Tensor: The training loss.\n", "\n", "        \"\"\"\n", "        x, y = batch\n", "        logits = self(x)\n", "        loss = F.nll_loss(logits, y)\n", "        return loss\n", "\n", "    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:\n", "        \"\"\"Defines a single validation step for the MLP.\n", "\n", "        Args:\n", "            batch : A tuple containing the input data and target labels.\n", "            batch_idx : The index of the current batch.\n", "\n", "        \"\"\"\n", "        x, y = batch\n", "        logits = self(x)\n", "        loss = F.nll_loss(logits, y)\n", "        preds = torch.argmax(logits, dim=1)\n", "        self.val_accuracy.update(preds, y)\n", "\n", "        # Calling self.log will surface up scalars for you in your logger (a CSVLogger here)\n", "        self.log(\"val_loss\", loss, prog_bar=True)\n", "        self.log(\"val_acc\", self.val_accuracy, prog_bar=True)\n", "\n", "    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:\n", "        \"\"\"Defines a single testing step for the MLP.\n", "\n", "        Args:\n", "            batch : A tuple containing the input data and target labels.\n", "            batch_idx : The index of the current batch.\n", "\n", "        \"\"\"\n", "        x, y = batch\n", "        logits = self(x)\n", "        loss = F.nll_loss(logits, y)\n", "        preds = torch.argmax(logits, dim=1)\n", "        self.test_accuracy.update(preds, y)\n", "\n", "        # Calling self.log will surface up scalars for you in your logger (a CSVLogger here)\n", "        self.log(\"test_loss\", loss, prog_bar=True)\n", "        self.log(\"test_acc\", self.test_accuracy, prog_bar=True)\n", "\n", "    def configure_optimizers(self) -> torch.optim.Optimizer:\n", "        \"\"\"Configures the optimizer to use for training the MLP.\n", "\n", "        Returns:\n", "            torch.optim.Optimizer: The optimizer.\n", "\n", "        \"\"\"\n", "        optimizer = 
torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", "\n", " return optimizer\n", "\n", " # ------------------------------------- #\n", " # DATA RELATED HOOKS\n", " # ------------------------------------- #\n", "\n", " def prepare_data(self) -> None:\n", " \"\"\"Downloads the MNIST dataset.\"\"\"\n", " MNIST(self.data_dir, train=True, download=True)\n", "\n", " MNIST(self.data_dir, train=False, download=True)\n", "\n", " def setup(self, stage: str = None) -> None:\n", " \"\"\"Splits the MNIST dataset into train, validation, and test sets.\n", "\n", " Args:\n", " stage : The current stage (either \"fit\" or \"test\"). Defaults to None.\n", "\n", " \"\"\"\n", " # Assign train/val datasets for use in dataloaders\n", " if stage == \"fit\" or stage is None:\n", " mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)\n", "\n", " self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])\n", "\n", " # Assign test dataset for use in dataloader(s)\n", " if stage == \"test\" or stage is None:\n", " self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)\n", "\n", " def train_dataloader(self) -> DataLoader:\n", " \"\"\"Returns a DataLoader for the training set.\n", "\n", " Returns:\n", " DataLoader: The training DataLoader.\n", "\n", " \"\"\"\n", " return DataLoader(self.mnist_train, batch_size=config.batch_size)\n", "\n", " def val_dataloader(self) -> DataLoader:\n", " \"\"\"Returns a DataLoader for the validation set.\n", "\n", " Returns:\n", " DataLoader: The validation DataLoader.\n", "\n", " \"\"\"\n", " return DataLoader(self.mnist_val, batch_size=config.batch_size)\n", "\n", " def test_dataloader(self) -> DataLoader:\n", " \"\"\"Returns a DataLoader for the test set.\n", "\n", " Returns:\n", " DataLoader: The test DataLoader.\n", "\n", " \"\"\"\n", " return DataLoader(self.mnist_test, batch_size=config.batch_size)"]}, {"cell_type": "code", "execution_count": 6, "id": "dd757f00", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T20:53:48.636025Z", "iopub.status.busy": "2025-04-03T20:53:48.635786Z", "iopub.status.idle": "2025-04-03T20:54:17.240844Z", "shell.execute_reply": "2025-04-03T20:54:17.239903Z"}, "papermill": {"duration": 28.613771, "end_time": "2025-04-03T20:54:17.243279", "exception": false, "start_time": "2025-04-03T20:53:48.629508", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["GPU available: True (cuda), used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stderr", "output_type": "stream", "text": ["LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\n", " | Name | Type | Params | Mode \n", "-------------------------------------------------------------\n", "0 | model | Sequential | 55.1 K | train\n", "1 | val_accuracy | MulticlassAccuracy | 0 | train\n", "2 | test_accuracy | MulticlassAccuracy | 0 | train\n", "-------------------------------------------------------------\n", "55.1 K Trainable params\n", "0 Non-trainable params\n", "55.1 K Total params\n", "0.220 Total estimated model params size (MB)\n", "11 Modules in train mode\n", "0 Modules in eval mode\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "b06e7dc4754045cab50324d13c347f42", "version_major": 2, "version_minor": 0}, "text/plain": ["Sanity 
Checking: | | 0/? [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "beed699aeea54e46a47611de51b4cd5b", "version_major": 2, "version_minor": 0}, "text/plain": ["Training: | | 0/? [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "e4745f2eb188404cbc71cfd57f3e6197", "version_major": 2, "version_minor": 0}, "text/plain": ["Validation: | | 0/? [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "15c1b64638fa46e49df46c6cb49bd5fe", "version_major": 2, "version_minor": 0}, "text/plain": ["Validation: | | 0/? [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "394994819a6e43609e10a6e94d53874a", "version_major": 2, "version_minor": 0}, "text/plain": ["Validation: | | 0/? [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["`Trainer.fit` stopped: `max_epochs=3` reached.\n"]}], "source": ["# Instantiate the LitMNIST model\n", "model = LitMNIST()\n", "\n", "# Instantiate a PyTorch Lightning trainer with the specified configuration\n", "trainer = pl.Trainer(\n", " accelerator=config.accelerator,\n", " devices=config.devices,\n", " max_epochs=config.max_epochs,\n", " logger=CSVLogger(save_dir=config.save_dir),\n", ")\n", "\n", "# Train the model using the trainer\n", "trainer.fit(model)"]}, {"cell_type": "markdown", "id": "4a41fbb7", "metadata": {"papermill": {"duration": 0.010619, "end_time": "2025-04-03T20:54:17.266435", "exception": false, "start_time": "2025-04-03T20:54:17.255816", "status": "completed"}, "tags": []}, "source": ["### Testing\n", "\n", "To test a model, call `trainer.test(model)`.\n", "\n", "Or, if you've just trained a model, you can just call `trainer.test()` and Lightning will automatically\n", "test using the best saved checkpoint (conditioned on val_loss)."]}, {"cell_type": "code", "execution_count": 7, "id": "329707b6", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T20:54:17.288897Z", "iopub.status.busy": "2025-04-03T20:54:17.288704Z", "iopub.status.idle": "2025-04-03T20:54:18.815398Z", "shell.execute_reply": "2025-04-03T20:54:18.814478Z"}, "papermill": {"duration": 1.540523, "end_time": "2025-04-03T20:54:18.817528", "exception": false, "start_time": "2025-04-03T20:54:17.277005", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["Restoring states from the checkpoint path at logs/lightning_logs/version_0/checkpoints/epoch=2-step=645.ckpt\n"]}, {"name": "stderr", "output_type": "stream", "text": ["LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Loaded model weights from the checkpoint at logs/lightning_logs/version_0/checkpoints/epoch=2-step=645.ckpt\n"]}, {"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 
'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "408fb8c049e24343b6efe7c735c4fc2c", "version_major": 2, "version_minor": 0}, "text/plain": ["Testing: | | 0/? [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stdout", "output_type": "stream", "text": ["\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", " Test metric DataLoader 0\n", "\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", " test_acc 0.9257000088691711\n", " test_loss 0.2555711567401886\n", "\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n"]}, {"data": {"text/plain": ["[{'test_loss': 0.2555711567401886, 'test_acc': 0.9257000088691711}]"]}, "execution_count": 7, "metadata": {}, "output_type": "execute_result"}], "source": ["trainer.test(ckpt_path=\"best\")"]}, {"cell_type": "markdown", "id": "715a74c9", "metadata": {"papermill": {"duration": 0.011285, "end_time": "2025-04-03T20:54:18.840340", "exception": false, "start_time": "2025-04-03T20:54:18.829055", "status": "completed"}, "tags": []}, "source": ["### Bonus Tip\n", "\n", "You can keep calling `trainer.fit(model)` as many times as you'd like to continue training"]}, {"cell_type": "code", "execution_count": 8, "id": "9960ca85", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T20:54:18.862477Z", "iopub.status.busy": 
"2025-04-03T20:54:18.862243Z", "iopub.status.idle": "2025-04-03T20:54:19.137231Z", "shell.execute_reply": "2025-04-03T20:54:19.136436Z"}, "papermill": {"duration": 0.287741, "end_time": "2025-04-03T20:54:19.139280", "exception": false, "start_time": "2025-04-03T20:54:18.851539", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["/usr/local/lib/python3.10/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory logs/lightning_logs/version_0/checkpoints exists and is not empty.\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"]}, {"name": "stderr", "output_type": "stream", "text": ["\n", " | Name | Type | Params | Mode \n", "-------------------------------------------------------------\n", "0 | model | Sequential | 55.1 K | train\n", "1 | val_accuracy | MulticlassAccuracy | 0 | train\n", "2 | test_accuracy | MulticlassAccuracy | 0 | train\n", "-------------------------------------------------------------\n", "55.1 K Trainable params\n", "0 Non-trainable params\n", "55.1 K Total params\n", "0.220 Total estimated model params size (MB)\n", "11 Modules in train mode\n", "0 Modules in eval mode\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "8f2d61af170941f4951c1b56d3a41064", "version_major": 2, "version_minor": 0}, "text/plain": ["Sanity Checking: | | 0/? [00:00, ?it/s]"]}, "metadata": {}, "output_type": "display_data"}, {"name": "stderr", "output_type": "stream", "text": ["`Trainer.fit` stopped: `max_epochs=3` reached.\n"]}], "source": ["trainer.fit(model)"]}, {"cell_type": "markdown", "id": "48ac84c1", "metadata": {"papermill": {"duration": 0.011711, "end_time": "2025-04-03T20:54:19.163188", "exception": false, "start_time": "2025-04-03T20:54:19.151477", "status": "completed"}, "tags": []}, "source": ["In Colab, you can use the TensorBoard magic function to view the logs that Lightning has created for you!"]}, {"cell_type": "code", "execution_count": 9, "id": "be95608c", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T20:54:19.186251Z", "iopub.status.busy": "2025-04-03T20:54:19.186052Z", "iopub.status.idle": "2025-04-03T20:54:19.466496Z", "shell.execute_reply": "2025-04-03T20:54:19.465499Z"}, "papermill": {"duration": 0.293674, "end_time": "2025-04-03T20:54:19.468780", "exception": false, "start_time": "2025-04-03T20:54:19.175106", "status": "completed"}, "tags": []}, "outputs": [{"data": {"text/html": ["
\n", " | test_acc | \n", "test_loss | \n", "val_acc | \n", "val_loss | \n", "
---|---|---|---|---|
epoch | \n", "\n", " | \n", " | \n", " | \n", " |
0 | \n", "NaN | \n", "NaN | \n", "0.8790 | \n", "0.442527 | \n", "
1 | \n", "NaN | \n", "NaN | \n", "0.9066 | \n", "0.326151 | \n", "
2 | \n", "NaN | \n", "NaN | \n", "0.9190 | \n", "0.283037 | \n", "
3 | \n", "0.9257 | \n", "0.255571 | \n", "NaN | \n", "NaN | \n", "