{"cells": [{"cell_type": "markdown", "id": "945cfe5e", "metadata": {"papermill": {"duration": 0.009871, "end_time": "2025-04-03T19:23:06.058462", "exception": false, "start_time": "2025-04-03T19:23:06.048591", "status": "completed"}, "tags": []}, "source": ["\n", "# Tutorial 7: Deep Energy-Based Generative Models\n", "\n", "* **Author:** Phillip Lippe\n", "* **License:** CC BY-SA\n", "* **Generated:** 2025-04-03T19:22:59.369888\n", "\n", "In this tutorial, we will look at energy-based deep learning models, and focus on their application as generative models.\n", "Energy models were a popular tool before the deep learning boom around 2012.\n", "In recent years, energy-based models have regained attention because of improved training methods and tricks.\n", "Although they are still at a research stage, they have been shown to outperform strong Generative Adversarial Networks (GANs),\n", "which have been the state of the art in image generation, in certain cases\n", "([blog post](https://ajolicoeur.wordpress.com/the-new-contender-to-gans-score-matching-with-langevin-sampling/) about strong energy-based models,\n", "[blog post](https://medium.com/syncedreview/nvidia-open-sources-hyper-realistic-face-generator-stylegan-f346e1a73826) about the power of GANs).\n", "Hence, it is important to be aware of energy-based models, and as the theory can be abstract at times,\n", "we will illustrate the idea of energy-based models with many examples.\n", "This notebook is part of a lecture series on Deep Learning at the University of Amsterdam.\n", "The full list of tutorials can be found at https://uvadlc-notebooks.rtfd.io.\n", "\n", "\n", "---\n", "[Open in Colab](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/course_UvA-DL/07-deep-energy-based-generative-models.ipynb)\n", "\n", "Give us a \u2b50 [on
Github](https://www.github.com/Lightning-AI/lightning/)\n", "| Check out [the documentation](https://lightning.ai/docs/)\n", "| Join us [on Discord](https://discord.com/invite/tfXFetEZxv)"]}, {"cell_type": "markdown", "id": "4a407e72", "metadata": {"papermill": {"duration": 0.008896, "end_time": "2025-04-03T19:23:06.075785", "exception": false, "start_time": "2025-04-03T19:23:06.066889", "status": "completed"}, "tags": []}, "source": ["## Setup\n", "This notebook requires some packages besides pytorch-lightning."]}, {"cell_type": "code", "execution_count": 1, "id": "6d4ee18f", "metadata": {"colab": {}, "colab_type": "code", "execution": {"iopub.execute_input": "2025-04-03T19:23:06.093383Z", "iopub.status.busy": "2025-04-03T19:23:06.093150Z", "iopub.status.idle": "2025-04-03T19:23:07.291349Z", "shell.execute_reply": "2025-04-03T19:23:07.289976Z"}, "id": "LfrJLKPFyhsK", "lines_to_next_cell": 0, "papermill": {"duration": 1.209946, "end_time": "2025-04-03T19:23:07.293929", "exception": false, "start_time": "2025-04-03T19:23:06.083983", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. 
Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\r\n", "\u001b[0m"]}, {"name": "stdout", "output_type": "stream", "text": ["\r\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.2\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.0.1\u001b[0m\r\n", "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\r\n"]}], "source": ["! pip install --quiet \"numpy <3.0\" \"tensorboard\" \"torchmetrics >=1.0,<1.8\" \"pytorch-lightning >=2.0,<2.6\" \"torch >=1.8.1,<2.7\" \"torchvision\" \"seaborn\" \"matplotlib\""]}, {"cell_type": "markdown", "id": "c05c196e", "metadata": {"papermill": {"duration": 0.013464, "end_time": "2025-04-03T19:23:07.321630", "exception": false, "start_time": "2025-04-03T19:23:07.308166", "status": "completed"}, "tags": []}, "source": ["
\n", "First, let's import our standard libraries below."]}, {"cell_type": "code", "execution_count": 2, "id": "41544bd4", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T19:23:07.349988Z", "iopub.status.busy": "2025-04-03T19:23:07.349610Z", "iopub.status.idle": "2025-04-03T19:23:10.624170Z", "shell.execute_reply": "2025-04-03T19:23:10.622828Z"}, "papermill": {"duration": 3.291496, "end_time": "2025-04-03T19:23:10.626613", "exception": false, "start_time": "2025-04-03T19:23:07.335117", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["Seed set to 42\n"]}], "source": ["# Standard libraries\n", "import os\n", "import random\n", "import urllib.request\n", "from urllib.error import HTTPError\n", "\n", "# Plotting\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "\n", "%matplotlib inline\n", "import matplotlib_inline.backend_inline\n", "import numpy as np\n", "\n", "# PyTorch Lightning\n", "import pytorch_lightning as pl\n", "\n", "# PyTorch\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "import torch.utils.data as data\n", "\n", "# Torchvision\n", "import torchvision\n", "from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint\n", "from torchvision import transforms\n", "from torchvision.datasets import MNIST\n", "\n", "matplotlib_inline.backend_inline.set_matplotlib_formats(\"svg\", \"pdf\") # For export\n", "matplotlib.rcParams[\"lines.linewidth\"] = 2.0\n", "\n", "# Path to the folder where the datasets are/should be downloaded (e.g. 
CIFAR10)\n", "DATASET_PATH = os.environ.get(\"PATH_DATASETS\", \"data\")\n", "# Path to the folder where the pretrained models are saved\n", "CHECKPOINT_PATH = os.environ.get(\"PATH_CHECKPOINT\", \"saved_models/tutorial8\")\n", "\n", "# Setting the seed\n", "pl.seed_everything(42)\n", "\n", "# Ensure that all operations are deterministic on GPU (if used) for reproducibility\n", "torch.backends.cudnn.deterministic = True\n", "torch.backends.cudnn.benchmark = False\n", "\n", "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")"]}, {"cell_type": "markdown", "id": "d2310e95", "metadata": {"papermill": {"duration": 0.013654, "end_time": "2025-04-03T19:23:10.654933", "exception": false, "start_time": "2025-04-03T19:23:10.641279", "status": "completed"}, "tags": []}, "source": ["We also have pre-trained models that we download below."]}, {"cell_type": "code", "execution_count": 3, "id": "62398548", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T19:23:10.682508Z", "iopub.status.busy": "2025-04-03T19:23:10.682091Z", "iopub.status.idle": "2025-04-03T19:23:11.018724Z", "shell.execute_reply": "2025-04-03T19:23:11.017381Z"}, "papermill": {"duration": 0.352597, "end_time": "2025-04-03T19:23:11.021206", "exception": false, "start_time": "2025-04-03T19:23:10.668609", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial8/MNIST.ckpt...\n", "Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial8/tensorboards/events.out.tfevents.MNIST...\n"]}], "source": ["# Github URL where saved models are stored for this tutorial\n", "base_url = \"https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial8/\"\n", "# Files to download\n", "pretrained_files = [\"MNIST.ckpt\", \"tensorboards/events.out.tfevents.MNIST\"]\n", "\n", "# Create checkpoint path if it doesn't exist yet\n", 
"os.makedirs(CHECKPOINT_PATH, exist_ok=True)\n", "\n", "# For each file, check whether it already exists. If not, try downloading it.\n", "for file_name in pretrained_files:\n", " file_path = os.path.join(CHECKPOINT_PATH, file_name)\n", " if \"/\" in file_name:\n", " os.makedirs(file_path.rsplit(\"/\", 1)[0], exist_ok=True)\n", " if not os.path.isfile(file_path):\n", " file_url = base_url + file_name\n", " print(f\"Downloading {file_url}...\")\n", " try:\n", " urllib.request.urlretrieve(file_url, file_path)\n", " except HTTPError as e:\n", " print(\n", " \"Something went wrong. Please try to download the files manually,\"\n", " \" or contact the author with the full output including the following error:\\n\",\n", " e,\n", " )"]}, {"cell_type": "markdown", "id": "215f5e9b", "metadata": {"papermill": {"duration": 0.013807, "end_time": "2025-04-03T19:23:11.049110", "exception": false, "start_time": "2025-04-03T19:23:11.035303", "status": "completed"}, "tags": []}, "source": ["## Energy Models\n", "\n", "In the first part of this tutorial, we will review the theory of the energy-based models\n", "(the same theory has been discussed in Lecture 8).\n", "While most of the previous models had the goal of classification or regression,\n", "energy-based models are motivated from a different perspective: density estimation.\n", "Given a dataset with a lot of elements, we want to estimate the probability distribution over the whole data space.\n", "As an example, if we model images from CIFAR10, our goal would be to have a probability distribution\n", "over all possible images of size $32\\times32\\times3$ where those images have a high likelihood\n", "that look realistic and are one of the 10 CIFAR classes.\n", "Simple methods like interpolation between images don't work because images are extremely high-dimensional\n", "(especially for large HD images).\n", "Hence, we turn to deep learning methods that have performed well on complex data.\n", "\n", "However, how do we 
predict a probability distribution $p(\\mathbf{x})$ over so many dimensions using a simple neural network?\n", "The problem is that we cannot just predict a score between 0 and 1,\n", "because a probability distribution over data needs to fulfill two properties:\n", "\n", "1. The probability distribution needs to assign any possible value of\n", "$\\mathbf{x}$ a non-negative value: $p(\\mathbf{x}) \\geq 0$.\n", "2. The probability density must sum/integrate to 1 over **all** possible inputs:\n", "$\\int_{\\mathbf{x}} p(\\mathbf{x}) d\\mathbf{x} = 1$.\n", "\n", "Luckily, there are many approaches for this, and energy-based models are one of them.\n", "The fundamental idea of energy-based models is that you can turn any function\n", "that predicts values larger than zero into a probability distribution by dividing by its volume.\n", "Imagine we have a neural network whose output is a single neuron, like in regression.\n", "We can call this network $E_{\\theta}(\\mathbf{x})$, where $\\theta$ are the parameters of the network,\n", "and $\\mathbf{x}$ the input data (e.g.
an image).\n", "The output of $E_{\\theta}$ is a scalar value between $-\\infty$ and $\\infty$.\n", "Now, we can use basic probability theory to *normalize* the scores of all possible inputs:\n", "\n", "$$\n", "q_{\\theta}(\\mathbf{x}) = \\frac{\\exp\\left(-E_{\\theta}(\\mathbf{x})\\right)}{Z_{\\theta}} \\hspace{5mm}\\text{where}\\hspace{5mm}\n", "Z_{\\theta} = \\begin{cases}\n", "    \\int_{\\mathbf{x}}\\exp\\left(-E_{\\theta}(\\mathbf{x})\\right) d\\mathbf{x} & \\text{if }\\mathbf{x}\\text{ is continuous}\\\\\n", "    \\sum_{\\mathbf{x}}\\exp\\left(-E_{\\theta}(\\mathbf{x})\\right) & \\text{if }\\mathbf{x}\\text{ is discrete}\n", "\\end{cases}\n", "$$\n", "\n", "The $\\exp$-function ensures that we assign a probability greater than zero to any possible input.\n", "We use a negative sign in front of $E$ because we call $E_{\\theta}$ the energy function:\n", "data points with high likelihood have a low energy, while data points with low likelihood have a high energy.\n", "$Z_{\\theta}$ is our normalization term, which ensures that the density integrates/sums to 1.\n", "We can show this by integrating over $q_{\\theta}(\\mathbf{x})$:\n", "\n", "$$\n", "\\int_{\\mathbf{x}}q_{\\theta}(\\mathbf{x})d\\mathbf{x} =\n", "\\int_{\\mathbf{x}}\\frac{\\exp\\left(-E_{\\theta}(\\mathbf{x})\\right)}{\\int_{\\mathbf{\\tilde{x}}}\\exp\\left(-E_{\\theta}(\\mathbf{\\tilde{x}})\\right) d\\mathbf{\\tilde{x}}}d\\mathbf{x} =\n", "\\frac{\\int_{\\mathbf{x}}\\exp\\left(-E_{\\theta}(\\mathbf{x})\\right)d\\mathbf{x}}{\\int_{\\mathbf{\\tilde{x}}}\\exp\\left(-E_{\\theta}(\\mathbf{\\tilde{x}})\\right) d\\mathbf{\\tilde{x}}} = 1\n", "$$\n", "\n", "Note that we call the probability distribution $q_{\\theta}(\\mathbf{x})$ because this is the distribution learned by the model,\n", "which is trained to be as close as possible to the *true*, unknown distribution $p(\\mathbf{x})$.\n", "\n", "The main benefit of this formulation of the probability distribution is its great flexibility as we can choose\n", "$E_{\\theta}$ in
whatever way we like, without any constraints.\n", "Nevertheless, when looking at the equation above, we can see a fundamental issue: How do we calculate $Z_{\\theta}$?\n", "There is no chance that we can calculate $Z_{\\theta}$ analytically for high-dimensional input\n", "and/or larger neural networks, but the task requires us to know $Z_{\\theta}$.\n", "Although we can't determine the exact likelihood of a point, there exist methods with which we can train energy-based models.\n", "Thus, we will look next at \"Contrastive Divergence\" for training the model."]}, {"cell_type": "markdown", "id": "f41a374b", "metadata": {"papermill": {"duration": 0.011494, "end_time": "2025-04-03T19:23:11.074370", "exception": false, "start_time": "2025-04-03T19:23:11.062876", "status": "completed"}, "tags": []}, "source": ["### Contrastive Divergence\n", "\n", "When we train a generative model, we usually do so by maximum likelihood estimation.\n", "In other words, we try to maximize the likelihood of the examples in the training set.\n", "As the exact likelihood of a point cannot be determined due to the unknown normalization constant $Z_{\\theta}$,\n", "we need to train energy-based models slightly differently.\n", "We cannot just maximize the un-normalized probability $\\exp(-E_{\\theta}(\\mathbf{x}_{\\text{train}}))$\n", "because there is no guarantee that $Z_{\\theta}$ stays constant, or that $\\mathbf{x}_{\\text{train}}$\n", "becomes more likely than other points.\n", "However, if we base our training on comparing the likelihood of points, we can create a stable objective.\n", "Namely, we can re-write our maximum likelihood objective so that we maximize the probability\n", "of $\\mathbf{x}_{\\text{train}}$ compared to a randomly sampled data point of our model:\n", "\n", "$$\n", "\\begin{split}\n", "    \\nabla_{\\theta}\\mathcal{L}_{\\text{MLE}}(\\mathbf{\\theta};p) & = -\\mathbb{E}_{p(\\mathbf{x})}\\left[\\nabla_{\\theta}\\log
q_{\\theta}(\\mathbf{x})\\right]\\\\[5pt]\n", " & = \\mathbb{E}_{p(\\mathbf{x})}\\left[\\nabla_{\\theta}E_{\\theta}(\\mathbf{x})\\right] - \\mathbb{E}_{q_{\\theta}(\\mathbf{x})}\\left[\\nabla_{\\theta}E_{\\theta}(\\mathbf{x})\\right]\n", "\\end{split}\n", "$$\n", "\n", "Note that the loss is still an objective we want to minimize.\n", "Thus, we try to minimize the energy for data points from the dataset, while maximizing the energy for randomly\n", "sampled data points from our model (how we sample will be explained below).\n", "Although this objective sounds intuitive, how is it actually derived from our original distribution $q_{\\theta}(\\mathbf{x})$?\n", "The trick is that we approximate $Z_{\\theta}$ by a single Monte-Carlo sample.\n", "This gives us the exact same objective as written above.\n", "\n", "Visually, we can look at the objective as follows (figure credit - Stefano Ermon and Aditya Grover: lecture cs236/11):\n", "\n", "
\n", "\n", "$f_{\\theta}$ represents $\\exp(-E_{\\theta}(\\mathbf{x}))$ in our case.\n", "The point on the right, called \"correct answer\", represents a data point from the dataset\n", "(i.e. $x_{\\text{train}}$), and the left point, \"wrong answer\", a sample from our model (i.e. $x_{\\text{sample}}$).\n", "Thus, we try to \"pull up\" the probability of the data points in the dataset,\n", "while \"pushing down\" randomly sampled points.\n", "The two forces for pulling and pushing are in balance iff $q_{\\theta}(\\mathbf{x})=p(\\mathbf{x})$."]}, {"cell_type": "markdown", "id": "dfff9b15", "metadata": {"papermill": {"duration": 0.008414, "end_time": "2025-04-03T19:23:11.091227", "exception": false, "start_time": "2025-04-03T19:23:11.082813", "status": "completed"}, "tags": []}, "source": ["### Sampling from Energy-Based Models\n", "\n", "For sampling from an energy-based model, we can apply Markov Chain Monte Carlo (MCMC) using Langevin dynamics.\n", "The idea of the algorithm is to start from a random point, and slowly move in the direction\n", "of higher probability using the gradients of $E_{\\theta}$.\n", "Nevertheless, this is not enough to fully capture the probability distribution.\n", "We need to add noise $\\omega$ at each gradient step to the current sample.\n", "Under certain conditions, such as performing an infinite number of gradient steps,\n", "we would be able to create an exact sample from our modeled distribution.\n", "However, as this is not practically possible, we usually limit the chain to $K$ steps\n", "($K$ is a hyperparameter that needs to be tuned).\n", "Overall, the sampling procedure can be summarized in the following algorithm:\n", "\n", "
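As a rough illustration of the sampling steps described above, here is a minimal sketch on a toy one-dimensional problem. It is a sketch under stated assumptions, not the sampler used in this notebook: the quadratic energy behind `toy_energy_grad`, the function name `langevin_sample`, and the values of `steps` and `step_size` are all hypothetical choices. The gradient is written out analytically instead of being obtained by backpropagation, and the noise scale is tied to the step size as in textbook Langevin dynamics, whereas the `Sampler` class later in this notebook treats the noise scale as a separate hyperparameter.

```python
import math
import random


def toy_energy_grad(x, mean=2.0):
    # Gradient of the assumed toy energy E(x) = 0.5 * (x - mean)^2
    return x - mean


def langevin_sample(x0, rng, steps=200, step_size=0.1):
    # Each step: move against the energy gradient, then add Gaussian noise
    noise_scale = math.sqrt(2.0 * step_size)
    x = x0
    for _ in range(steps):
        x = x - step_size * toy_energy_grad(x) + noise_scale * rng.gauss(0.0, 1.0)
    return x


rng = random.Random(42)
# Start every chain from uniform noise in [-1, 1], as done for images later on
samples = [langevin_sample(rng.uniform(-1.0, 1.0), rng) for _ in range(1000)]
sample_mean = sum(samples) / len(samples)
```

For this toy energy, the chains should concentrate around the minimum of the energy (here `mean=2.0`), which is the region of highest probability under the corresponding distribution.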
"]}, {"cell_type": "markdown", "id": "b2feca54", "metadata": {"papermill": {"duration": 0.007102, "end_time": "2025-04-03T19:23:11.106731", "exception": false, "start_time": "2025-04-03T19:23:11.099629", "status": "completed"}, "tags": []}, "source": ["### Applications of Energy-based models beyond generation\n", "\n", "Modeling the probability distribution for sampling new data is not the only application of energy-based models.\n", "Any application which requires us to compare two elements is much simpler to learn\n", "because we just need to go for the higher energy.\n", "A couple of examples are shown below (figure credit - Stefano Ermon and Aditya Grover: lecture cs236/11).\n", "A classification setup like object recognition or sequence labeling can be considered as an energy-based\n", "task as we just need to find the $Y$ input that minimizes the output $E(X, Y)$ (hence maximizes probability).\n", "Similarly, a popular application of energy-based models is denoising of images.\n", "Given an image $X$ with a lot of noise, we try to minimize the energy by finding the true input image $Y$.\n", "\n", "
\n", "\n", "Nonetheless, we will focus on generative modeling here as in the next couple of lectures,\n", "we will discuss more generative deep learning approaches."]}, {"cell_type": "markdown", "id": "cb7b6d5a", "metadata": {"papermill": {"duration": 0.006077, "end_time": "2025-04-03T19:23:11.118923", "exception": false, "start_time": "2025-04-03T19:23:11.112846", "status": "completed"}, "tags": []}, "source": ["## Image generation\n", "\n", "
\n", "\n", "As an example for energy-based models, we will train a model on image generation.\n", "Specifically, we will look at how we can generate MNIST digits with a very simple CNN model.\n", "However, it should be noted that energy models are not easy to train and often diverge\n", "if the hyperparameters are not well tuned.\n", "We will rely on training tricks proposed in the paper\n", "[Implicit Generation and Generalization in Energy-Based Models](https://arxiv.org/abs/1903.08689)\n", "by Yilun Du and Igor Mordatch ([blog](https://openai.com/index/energy-based-models/)).\n", "The important part of this notebook is however to see how the theory above can actually be used in a model.\n", "\n", "### Dataset\n", "\n", "First, we can load the MNIST dataset below.\n", "Note that we need to normalize the images between -1 and 1 instead of mean 0 and std 1 because during sampling,\n", "we have to limit the input space.\n", "Scaling between -1 and 1 makes it easier to implement it."]}, {"cell_type": "code", "execution_count": 4, "id": "c45a5c65", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T19:23:11.132897Z", "iopub.status.busy": "2025-04-03T19:23:11.132510Z", "iopub.status.idle": "2025-04-03T19:23:13.673293Z", "shell.execute_reply": "2025-04-03T19:23:13.672305Z"}, "papermill": {"duration": 2.54975, "end_time": "2025-04-03T19:23:13.674790", "exception": false, "start_time": "2025-04-03T19:23:11.125040", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", "Failed to download (trying next):\n", "HTTP Error 404: Not Found\n", "\n", "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /__w/11/s/.datasets/MNIST/raw/train-images-idx3-ubyte.gz\n"]}, 
{"name": "stderr", "output_type": "stream", "text": ["\r"]}], "source": ["# Transformations applied on each image: make them a tensor and normalize between -1 and 1\n", "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])\n", "\n", "# Loading the training dataset\n", "train_set = MNIST(root=DATASET_PATH, train=True, transform=transform, download=True)\n", "\n", "# Loading the test set\n", "test_set = MNIST(root=DATASET_PATH, train=False, transform=transform, download=True)\n", "\n", "# We define a set of data loaders that we can use for various purposes later.\n", "# Note that for actually training a model, we will use different data loaders\n", "# with a lower batch size.\n", "train_loader = data.DataLoader(train_set, batch_size=128, shuffle=True, drop_last=True, num_workers=4, pin_memory=True)\n", "test_loader = data.DataLoader(test_set, batch_size=256, shuffle=False, drop_last=False, num_workers=4)"]}, {"cell_type": "markdown", "id": "803ca18b", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.009937, "end_time": "2025-04-03T19:23:13.695308", "exception": false, "start_time": "2025-04-03T19:23:13.685371", "status": "completed"}, "tags": []}, "source": ["### CNN Model\n", "\n", "First, we implement our CNN model.\n", "The MNIST images are of size 28x28, hence we only need a small model.\n", "As an example, we will apply several convolutions with stride 2 that downscale the images.\n", "If you are interested, you can also use a deeper model such as a small ResNet, but for simplicity,\n", "we will stick with the tiny network.\n", "\n", "It is good practice to use a smooth activation function like Swish instead of ReLU in the energy model.\n", "This is because we will rely on the gradients we get back with respect to the input image, which should not be sparse."]}, {"cell_type": "code", "execution_count": 5, "id": "36984381", "metadata": {"execution": {"iopub.execute_input":
"2025-04-03T19:23:13.711322Z", "iopub.status.busy": "2025-04-03T19:23:13.710621Z", "iopub.status.idle": "2025-04-03T19:23:13.718616Z", "shell.execute_reply": "2025-04-03T19:23:13.717722Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.017411, "end_time": "2025-04-03T19:23:13.719873", "exception": false, "start_time": "2025-04-03T19:23:13.702462", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class CNNModel(nn.Module):\n", " def __init__(self, hidden_features=32, out_dim=1, **kwargs):\n", " super().__init__()\n", " # We increase the hidden dimension over layers. Here pre-calculated for simplicity.\n", " c_hid1 = hidden_features // 2\n", " c_hid2 = hidden_features\n", " c_hid3 = hidden_features * 2\n", "\n", " # Series of convolutions and Swish activation functions\n", " self.cnn_layers = nn.Sequential(\n", " nn.Conv2d(1, c_hid1, kernel_size=5, stride=2, padding=4), # [16x16] - Larger padding to get 32x32 image\n", " nn.SiLU(),\n", " nn.Conv2d(c_hid1, c_hid2, kernel_size=3, stride=2, padding=1), # [8x8]\n", " nn.SiLU(),\n", " nn.Conv2d(c_hid2, c_hid3, kernel_size=3, stride=2, padding=1), # [4x4]\n", " nn.SiLU(),\n", " nn.Conv2d(c_hid3, c_hid3, kernel_size=3, stride=2, padding=1), # [2x2]\n", " nn.SiLU(),\n", " nn.Flatten(),\n", " nn.Linear(c_hid3 * 4, c_hid3),\n", " nn.SiLU(),\n", " nn.Linear(c_hid3, out_dim),\n", " )\n", "\n", " def forward(self, x):\n", " x = self.cnn_layers(x).squeeze(dim=-1)\n", " return x"]}, {"cell_type": "markdown", "id": "72b153d0", "metadata": {"papermill": {"duration": 0.007136, "end_time": "2025-04-03T19:23:13.734208", "exception": false, "start_time": "2025-04-03T19:23:13.727072", "status": "completed"}, "tags": []}, "source": ["In the rest of the notebook, the output of the model will actually not represent\n", "$E_{\\theta}(\\mathbf{x})$, but $-E_{\\theta}(\\mathbf{x})$.\n", "This is a standard implementation practice for energy-based models, as some people also write the energy probability\n", "density as 
$q_{\\theta}(\\mathbf{x}) = \\frac{\\exp\\left(f_{\\theta}(\\mathbf{x})\\right)}{Z_{\\theta}}$.\n", "In that case, the model would actually represent $f_{\\theta}(\\mathbf{x})$.\n", "In the training loss etc., we need to be careful not to mix up the signs."]}, {"cell_type": "markdown", "id": "d6f0ccd9", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.007107, "end_time": "2025-04-03T19:23:13.748459", "exception": false, "start_time": "2025-04-03T19:23:13.741352", "status": "completed"}, "tags": []}, "source": ["### Sampling buffer\n", "\n", "In the next part, we look at the training with sampled elements.\n", "To use the contrastive divergence objective, we need to generate samples during training.\n", "Previous work has shown that due to the high dimensionality of images, we need a lot of iterations\n", "inside the MCMC sampling to obtain reasonable samples.\n", "However, there is a training trick that significantly reduces the sampling cost: using a sampling buffer.\n", "The idea is that we store the samples of the last couple of batches in a buffer,\n", "and reuse those as the starting point of the MCMC algorithm for the next batches.\n", "This reduces the sampling cost because the model requires a significantly\n", "lower number of steps to converge to reasonable samples.\n", "However, to not solely rely on previous samples and allow novel samples as well,\n", "we re-initialize 5% of our samples from scratch (random noise between -1 and 1).\n", "\n", "Below, we implement the sampling buffer.\n", "The function `sample_new_exmps` returns a new batch of \"fake\" images.\n", "We refer to those as fake images because they have been generated, but are not actually part of the dataset.\n", "As mentioned before, we initialize 5% randomly, and 95% are randomly picked from our buffer.\n", "On this initial batch, we perform MCMC for 60 iterations to improve the image quality\n", "and come closer to samples from $q_{\\theta}(\\mathbf{x})$.\n", "In
the function `generate_samples`, we implemented the MCMC for images.\n", "Note that the hyperparameters `step_size`, `steps`, and the noise standard deviation\n", "$\\sigma$ are specifically set for MNIST, and need to be tuned again if you want to use a different dataset."]}, {"cell_type": "code", "execution_count": 6, "id": "11bbcb96", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T19:23:13.764020Z", "iopub.status.busy": "2025-04-03T19:23:13.763702Z", "iopub.status.idle": "2025-04-03T19:23:13.783928Z", "shell.execute_reply": "2025-04-03T19:23:13.782778Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.029715, "end_time": "2025-04-03T19:23:13.785348", "exception": false, "start_time": "2025-04-03T19:23:13.755633", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class Sampler:\n", "    def __init__(self, model, img_shape, sample_size, max_len=8192):\n", "        \"\"\"Sampler.\n", "\n", "        Args:\n", "            model: Neural network to use for modeling E_theta\n", "            img_shape: Shape of the images to model\n", "            sample_size: Batch size of the samples\n", "            max_len: Maximum number of data points to keep in the buffer\n", "\n", "        \"\"\"\n", "        super().__init__()\n", "        self.model = model\n", "        self.img_shape = img_shape\n", "        self.sample_size = sample_size\n", "        self.max_len = max_len\n", "        self.examples = [(torch.rand((1,) + img_shape) * 2 - 1) for _ in range(self.sample_size)]\n", "\n", "    def sample_new_exmps(self, steps=60, step_size=10):\n", "        \"\"\"Function for getting a new batch of \"fake\" images.\n", "\n", "        Args:\n", "            steps: Number of iterations in the MCMC algorithm\n", "            step_size: Learning rate nu in the algorithm above\n", "\n", "        \"\"\"\n", "        # Choose 95% of the batch from the buffer, 5% generate from scratch\n", "        n_new = np.random.binomial(self.sample_size, 0.05)\n", "        rand_imgs = torch.rand((n_new,) + self.img_shape) * 2 - 1\n", "        old_imgs = torch.cat(random.choices(self.examples, k=self.sample_size - n_new), dim=0)\n", 
"        inp_imgs = torch.cat([rand_imgs, old_imgs], dim=0).detach().to(device)\n", "\n", "        # Perform MCMC sampling\n", "        inp_imgs = Sampler.generate_samples(self.model, inp_imgs, steps=steps, step_size=step_size)\n", "\n", "        # Add new images to the buffer and remove old ones if needed\n", "        self.examples = list(inp_imgs.to(torch.device(\"cpu\")).chunk(self.sample_size, dim=0)) + self.examples\n", "        self.examples = self.examples[: self.max_len]\n", "        return inp_imgs\n", "\n", "    @staticmethod\n", "    def generate_samples(model, inp_imgs, steps=60, step_size=10, return_img_per_step=False):\n", "        \"\"\"Function for sampling images for a given model.\n", "\n", "        Args:\n", "            model: Neural network to use for modeling E_theta\n", "            inp_imgs: Images to start from for sampling. If you want to generate new images, enter noise between -1 and 1.\n", "            steps: Number of iterations in the MCMC algorithm.\n", "            step_size: Learning rate nu in the algorithm above\n", "            return_img_per_step: If True, we return the sample at every iteration of the MCMC\n", "\n", "        \"\"\"\n", "        # Before MCMC: set \"requires_grad=False\" for the model parameters\n", "        # because we are only interested in the gradients of the input.\n", "        is_training = model.training\n", "        model.eval()\n", "        for p in model.parameters():\n", "            p.requires_grad = False\n", "        inp_imgs.requires_grad = True\n", "\n", "        # Enable gradient calculation if not already the case\n", "        had_gradients_enabled = torch.is_grad_enabled()\n", "        torch.set_grad_enabled(True)\n", "\n", "        # We use a buffer tensor in which we generate noise each loop iteration.\n", "        # More efficient than creating a new tensor every iteration.\n", "        noise = torch.randn(inp_imgs.shape, device=inp_imgs.device)\n", "\n", "        # List for storing generations at each step (for later analysis)\n", "        imgs_per_step = []\n", "\n", "        # Loop over K (steps)\n", "        for _ in range(steps):\n", "            # Part 1: Add noise to the input.\n", "            noise.normal_(0, 0.005)\n", "            inp_imgs.data.add_(noise.data)\n", 
"            inp_imgs.data.clamp_(min=-1.0, max=1.0)\n", "\n", "            # Part 2: calculate gradients for the current input.\n", "            out_imgs = -model(inp_imgs)\n", "            out_imgs.sum().backward()\n", "            inp_imgs.grad.data.clamp_(-0.03, 0.03)  # For stabilizing and preventing too high gradients\n", "\n", "            # Apply gradients to our current samples\n", "            inp_imgs.data.add_(-step_size * inp_imgs.grad.data)\n", "            inp_imgs.grad.detach_()\n", "            inp_imgs.grad.zero_()\n", "            inp_imgs.data.clamp_(min=-1.0, max=1.0)\n", "\n", "            if return_img_per_step:\n", "                imgs_per_step.append(inp_imgs.clone().detach())\n", "\n", "        # Reactivate gradients for parameters for training\n", "        for p in model.parameters():\n", "            p.requires_grad = True\n", "        model.train(is_training)\n", "\n", "        # Reset gradient calculation to setting before this function\n", "        torch.set_grad_enabled(had_gradients_enabled)\n", "\n", "        if return_img_per_step:\n", "            return torch.stack(imgs_per_step, dim=0)\n", "        else:\n", "            return inp_imgs"]}, {"cell_type": "markdown", "id": "36de51b3", "metadata": {"papermill": {"duration": 0.010126, "end_time": "2025-04-03T19:23:13.802753", "exception": false, "start_time": "2025-04-03T19:23:13.792627", "status": "completed"}, "tags": []}, "source": ["The idea of the buffer becomes a bit clearer in the following algorithm."]}, {"cell_type": "markdown", "id": "7464a5c9", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.007151, "end_time": "2025-04-03T19:23:13.817261", "exception": false, "start_time": "2025-04-03T19:23:13.810110", "status": "completed"}, "tags": []}, "source": ["### Training algorithm\n", "\n", "With the sampling buffer ready, we can complete our training algorithm.\n", "A summary of the full training algorithm of an energy model on image modeling is shown below:\n", "\n", "
\n", "\n", "The first few statements in each training iteration concern the sampling of the real and fake data,\n", "as we have seen above with the sample buffer.\n", "Next, we calculate the contrastive divergence objective using our energy model $E_{\\theta}$.\n", "However, one additional training trick we need is to add a regularization loss on the output of $E_{\\theta}$.\n", "As the output of the network is not constrained, and adding a constant bias to the output\n", "doesn't change the contrastive divergence loss, we need another way to ensure that the output values stay in a reasonable range.\n", "Without the regularization loss, the output values will fluctuate in a very large range.\n", "With the regularization loss, we ensure that the values for the real data are around 0, and those for the fake data likely slightly lower\n", "(for noise or outliers the score can still be significantly lower).\n", "As the regularization loss is less important than the contrastive divergence, we have a weight factor\n", "$\\alpha$ which is usually considerably smaller than 1.\n", "Finally, we perform an update step with an optimizer on the combined loss and add the new samples to the buffer.\n", "\n", "Below, we put this training dynamic into a PyTorch Lightning module:"]}, {"cell_type": "code", "execution_count": 7, "id": "32986a5c", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T19:23:13.863485Z", "iopub.status.busy": "2025-04-03T19:23:13.863129Z", "iopub.status.idle": "2025-04-03T19:23:13.879514Z", "shell.execute_reply": "2025-04-03T19:23:13.878400Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.026524, "end_time": "2025-04-03T19:23:13.880966", "exception": false, "start_time": "2025-04-03T19:23:13.854442", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class DeepEnergyModel(pl.LightningModule):\n", "    def __init__(self, img_shape, batch_size, alpha=0.1, lr=1e-4, beta1=0.0, **CNN_args):\n", "        super().__init__()\n", "        self.save_hyperparameters()\n", "\n", "        
self.cnn = CNNModel(**CNN_args)\n", "        self.sampler = Sampler(self.cnn, img_shape=img_shape, sample_size=batch_size)\n", "        self.example_input_array = torch.zeros(1, *img_shape)\n", "\n", "    def forward(self, x):\n", "        z = self.cnn(x)\n", "        return z\n", "\n", "    def configure_optimizers(self):\n", "        # Energy models can have issues with momentum as the loss surface changes with the model's parameters.\n", "        # Hence, we set it to 0 by default.\n", "        optimizer = optim.Adam(self.parameters(), lr=self.hparams.lr, betas=(self.hparams.beta1, 0.999))\n", "        scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.97)  # Exponential decay over epochs\n", "        return [optimizer], [scheduler]\n", "\n", "    def training_step(self, batch, batch_idx):\n", "        # We add minimal noise to the original images to prevent the model from focusing on purely \"clean\" inputs\n", "        real_imgs, _ = batch\n", "        small_noise = torch.randn_like(real_imgs) * 0.005\n", "        real_imgs.add_(small_noise).clamp_(min=-1.0, max=1.0)\n", "\n", "        # Obtain samples\n", "        fake_imgs = self.sampler.sample_new_exmps(steps=60, step_size=10)\n", "\n", "        # Predict energy score for all images\n", "        inp_imgs = torch.cat([real_imgs, fake_imgs], dim=0)\n", "        real_out, fake_out = self.cnn(inp_imgs).chunk(2, dim=0)\n", "\n", "        # Calculate losses\n", "        reg_loss = self.hparams.alpha * (real_out**2 + fake_out**2).mean()\n", "        cdiv_loss = fake_out.mean() - real_out.mean()\n", "        loss = reg_loss + cdiv_loss\n", "\n", "        # Logging\n", "        self.log(\"loss\", loss)\n", "        self.log(\"loss_regularization\", reg_loss)\n", "        self.log(\"loss_contrastive_divergence\", cdiv_loss)\n", "        self.log(\"metrics_avg_real\", real_out.mean())\n", "        self.log(\"metrics_avg_fake\", fake_out.mean())\n", "        return loss\n", "\n", "    def validation_step(self, batch, batch_idx):\n", "        # For validating, we calculate the contrastive divergence between purely random images and unseen examples\n", "        # Note that the validation/test step of energy-based models depends on what we are 
interested in the model\n", "        real_imgs, _ = batch\n", "        fake_imgs = torch.rand_like(real_imgs) * 2 - 1\n", "\n", "        inp_imgs = torch.cat([real_imgs, fake_imgs], dim=0)\n", "        real_out, fake_out = self.cnn(inp_imgs).chunk(2, dim=0)\n", "\n", "        cdiv = fake_out.mean() - real_out.mean()\n", "        self.log(\"val_contrastive_divergence\", cdiv)\n", "        self.log(\"val_fake_out\", fake_out.mean())\n", "        self.log(\"val_real_out\", real_out.mean())"]}, {"cell_type": "markdown", "id": "09ed0efc", "metadata": {"papermill": {"duration": 0.007182, "end_time": "2025-04-03T19:23:13.895419", "exception": false, "start_time": "2025-04-03T19:23:13.888237", "status": "completed"}, "tags": []}, "source": ["We do not implement a test step because energy-based generative models are usually not evaluated on a test set.\n", "The validation step, however, is used to get an idea of the difference between the energy/likelihood\n", "of random images and that of unseen examples of the dataset."]}, {"cell_type": "markdown", "id": "a7e35b14", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.007192, "end_time": "2025-04-03T19:23:13.910001", "exception": false, "start_time": "2025-04-03T19:23:13.902809", "status": "completed"}, "tags": []}, "source": ["### Callbacks\n", "\n", "To track the performance of our model during training, we will make extensive use of PyTorch Lightning's callback framework.\n", "Remember that callbacks can be used for running small functions at any point of the training,\n", "for instance after finishing an epoch.\n", "Here, we will use three different callbacks we define ourselves.\n", "\n", "The first callback, called `GenerateCallback`, is used for logging images generated by the model during training.\n", "After every $N$ epochs (usually $N=5$ to reduce output to TensorBoard), we take a small batch\n", "of random images and perform many MCMC iterations until the model's generation converges.\n", "Compared to the training that used 60 iterations, we use 256 here 
because\n", "(1) we only have to do it once compared to the training that has to do it every iteration, and\n", "(2) we do not start from a buffer here, but from scratch.\n", "It is implemented as follows:"]}, {"cell_type": "code", "execution_count": 8, "id": "c6007f74", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T19:23:13.925886Z", "iopub.status.busy": "2025-04-03T19:23:13.925534Z", "iopub.status.idle": "2025-04-03T19:23:13.936945Z", "shell.execute_reply": "2025-04-03T19:23:13.935817Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.021136, "end_time": "2025-04-03T19:23:13.938375", "exception": false, "start_time": "2025-04-03T19:23:13.917239", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class GenerateCallback(Callback):\n", "    def __init__(self, batch_size=8, vis_steps=8, num_steps=256, every_n_epochs=5):\n", "        super().__init__()\n", "        self.batch_size = batch_size  # Number of images to generate\n", "        self.vis_steps = vis_steps  # Number of steps within generation to visualize\n", "        self.num_steps = num_steps  # Number of steps to take during generation\n", "        # Only save those images every N epochs (otherwise tensorboard gets quite large)\n", "        self.every_n_epochs = every_n_epochs\n", "\n", "    # Lightning 2.x no longer provides `on_epoch_end`; `on_train_epoch_end` is the corresponding hook\n", "    def on_train_epoch_end(self, trainer, pl_module):\n", "        # Skip for all other epochs\n", "        if trainer.current_epoch % self.every_n_epochs == 0:\n", "            # Generate images\n", "            imgs_per_step = self.generate_imgs(pl_module)\n", "            # Plot and add to tensorboard\n", "            for i in range(imgs_per_step.shape[1]):\n", "                step_size = self.num_steps // self.vis_steps\n", "                imgs_to_plot = imgs_per_step[step_size - 1 :: step_size, i]\n", "                grid = torchvision.utils.make_grid(\n", "                    imgs_to_plot, nrow=imgs_to_plot.shape[0], normalize=True, value_range=(-1, 1)\n", "                )\n", "                trainer.logger.experiment.add_image(f\"generation_{i}\", grid, global_step=trainer.current_epoch)\n", "\n", "    def generate_imgs(self, pl_module):\n", "        pl_module.eval()\n", "        start_imgs = 
torch.rand((self.batch_size,) + pl_module.hparams[\"img_shape\"]).to(pl_module.device)\n", "        start_imgs = start_imgs * 2 - 1\n", "        imgs_per_step = Sampler.generate_samples(\n", "            pl_module.cnn, start_imgs, steps=self.num_steps, step_size=10, return_img_per_step=True\n", "        )\n", "        pl_module.train()\n", "        return imgs_per_step"]}, {"cell_type": "markdown", "id": "ef3df449", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.007238, "end_time": "2025-04-03T19:23:13.952912", "exception": false, "start_time": "2025-04-03T19:23:13.945674", "status": "completed"}, "tags": []}, "source": ["The second callback is called `SamplerCallback`, and simply adds a randomly picked subset of the images\n", "in the sampling buffer to TensorBoard.\n", "This helps to understand what images are currently shown to the model as \"fake\"."]}, {"cell_type": "code", "execution_count": 9, "id": "1d32b15b", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T19:23:13.968867Z", "iopub.status.busy": "2025-04-03T19:23:13.968517Z", "iopub.status.idle": "2025-04-03T19:23:13.976524Z", "shell.execute_reply": "2025-04-03T19:23:13.975393Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.017742, "end_time": "2025-04-03T19:23:13.977957", "exception": false, "start_time": "2025-04-03T19:23:13.960215", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class SamplerCallback(Callback):\n", "    def __init__(self, num_imgs=32, every_n_epochs=5):\n", "        super().__init__()\n", "        self.num_imgs = num_imgs  # Number of images to plot\n", "        # Only save those images every N epochs (otherwise tensorboard gets quite large)\n", "        self.every_n_epochs = every_n_epochs\n", "\n", "    # Lightning 2.x no longer provides `on_epoch_end`; `on_train_epoch_end` is the corresponding hook\n", "    def on_train_epoch_end(self, trainer, pl_module):\n", "        if trainer.current_epoch % self.every_n_epochs == 0:\n", "            exmp_imgs = torch.cat(random.choices(pl_module.sampler.examples, k=self.num_imgs), dim=0)\n", "            grid = torchvision.utils.make_grid(exmp_imgs, nrow=4, normalize=True, value_range=(-1, 1))\n", "            
trainer.logger.experiment.add_image(\"sampler\", grid, global_step=trainer.current_epoch)"]}, {"cell_type": "markdown", "id": "73246f33", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.007291, "end_time": "2025-04-03T19:23:13.992568", "exception": false, "start_time": "2025-04-03T19:23:13.985277", "status": "completed"}, "tags": []}, "source": ["Finally, our last callback is `OutlierCallback`.\n", "This callback evaluates the model by recording the (negative) energy assigned to random noise.\n", "While our training loss is almost constant across iterations,\n", "this score likely shows the progress of the model in detecting \"outliers\"."]}, {"cell_type": "code", "execution_count": 10, "id": "4559d70a", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T19:23:14.008675Z", "iopub.status.busy": "2025-04-03T19:23:14.008325Z", "iopub.status.idle": "2025-04-03T19:23:14.016176Z", "shell.execute_reply": "2025-04-03T19:23:14.015143Z"}, "lines_to_next_cell": 2, "papermill": {"duration": 0.018164, "end_time": "2025-04-03T19:23:14.018064", "exception": false, "start_time": "2025-04-03T19:23:13.999900", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class OutlierCallback(Callback):\n", "    def __init__(self, batch_size=1024):\n", "        super().__init__()\n", "        self.batch_size = batch_size\n", "\n", "    # Lightning 2.x no longer provides `on_epoch_end`; `on_train_epoch_end` is the corresponding hook\n", "    def on_train_epoch_end(self, trainer, pl_module):\n", "        with torch.no_grad():\n", "            pl_module.eval()\n", "            rand_imgs = torch.rand((self.batch_size,) + pl_module.hparams[\"img_shape\"]).to(pl_module.device)\n", "            rand_imgs = rand_imgs * 2 - 1.0\n", "            rand_out = pl_module.cnn(rand_imgs).mean()\n", "            pl_module.train()\n", "\n", "        trainer.logger.experiment.add_scalar(\"rand_out\", rand_out, global_step=trainer.current_epoch)"]}, {"cell_type": "markdown", "id": "46d6054d", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.007314, "end_time": "2025-04-03T19:23:14.032850", "exception": false, "start_time": "2025-04-03T19:23:14.025536", 
"status": "completed"}, "tags": []}, "source": ["### Running the model\n", "\n", "Finally, we can put everything together to create our final training function.\n", "The function is very similar to any other PyTorch Lightning training function we have seen so far.\n", "However, there is one small difference: we do not test the model on a test set\n", "because we will analyse the model afterward by checking its predictions and its ability to perform outlier detection."]}, {"cell_type": "code", "execution_count": 11, "id": "6674cc5b", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T19:23:14.048994Z", "iopub.status.busy": "2025-04-03T19:23:14.048640Z", "iopub.status.idle": "2025-04-03T19:23:14.057602Z", "shell.execute_reply": "2025-04-03T19:23:14.056491Z"}, "papermill": {"duration": 0.01879, "end_time": "2025-04-03T19:23:14.059001", "exception": false, "start_time": "2025-04-03T19:23:14.040211", "status": "completed"}, "tags": []}, "outputs": [], "source": ["def train_model(**kwargs):\n", "    # Create a PyTorch Lightning trainer with the generation callback\n", "    trainer = pl.Trainer(\n", "        default_root_dir=os.path.join(CHECKPOINT_PATH, \"MNIST\"),\n", "        accelerator=\"auto\",\n", "        devices=1,\n", "        max_epochs=60,\n", "        gradient_clip_val=0.1,\n", "        callbacks=[\n", "            ModelCheckpoint(save_weights_only=True, mode=\"min\", monitor=\"val_contrastive_divergence\"),\n", "            GenerateCallback(every_n_epochs=5),\n", "            SamplerCallback(every_n_epochs=5),\n", "            OutlierCallback(),\n", "            LearningRateMonitor(\"epoch\"),\n", "        ],\n", "    )\n", "    # Check whether pretrained model exists. 
If yes, load it and skip training\n", "    pretrained_filename = os.path.join(CHECKPOINT_PATH, \"MNIST.ckpt\")\n", "    if os.path.isfile(pretrained_filename):\n", "        print(\"Found pretrained model, loading...\")\n", "        model = DeepEnergyModel.load_from_checkpoint(pretrained_filename)\n", "    else:\n", "        pl.seed_everything(42)\n", "        model = DeepEnergyModel(**kwargs)\n", "        trainer.fit(model, train_loader, test_loader)\n", "        model = DeepEnergyModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)\n", "    # No testing as we are more interested in other properties\n", "    return model"]}, {"cell_type": "code", "execution_count": 12, "id": "ab5aa5ef", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T19:23:14.075219Z", "iopub.status.busy": "2025-04-03T19:23:14.074868Z", "iopub.status.idle": "2025-04-03T19:23:14.712453Z", "shell.execute_reply": "2025-04-03T19:23:14.711295Z"}, "papermill": {"duration": 0.648288, "end_time": "2025-04-03T19:23:14.714711", "exception": false, "start_time": "2025-04-03T19:23:14.066423", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["GPU available: True (cuda), used: True\n"]}, {"name": "stderr", "output_type": "stream", "text": ["TPU available: False, using: 0 TPU cores\n"]}, {"name": "stderr", "output_type": "stream", "text": ["HPU available: False, using: 0 HPUs\n"]}, {"name": "stdout", "output_type": "stream", "text": ["Found pretrained model, loading...\n"]}, {"name": "stderr", "output_type": "stream", "text": ["Lightning automatically upgraded your loaded checkpoint from v1.0.2 to v2.4.0. 
To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint saved_models/tutorial8/MNIST.ckpt`\n"]}], "source": ["model = train_model(img_shape=(1, 28, 28), batch_size=train_loader.batch_size, lr=1e-4, beta1=0.0)"]}, {"cell_type": "markdown", "id": "ce6dbdcf", "metadata": {"papermill": {"duration": 0.016957, "end_time": "2025-04-03T19:23:14.748772", "exception": false, "start_time": "2025-04-03T19:23:14.731815", "status": "completed"}, "tags": []}, "source": ["## Analysis\n", "\n", "In the last part of the notebook, we take the trained energy-based generative model\n", "and analyse its properties."]}, {"cell_type": "markdown", "id": "d3e9fa38", "metadata": {"papermill": {"duration": 0.010689, "end_time": "2025-04-03T19:23:14.775651", "exception": false, "start_time": "2025-04-03T19:23:14.764962", "status": "completed"}, "tags": []}, "source": ["### TensorBoard\n", "\n", "The first thing we can look at is the TensorBoard generated during training.\n", "This can help us to understand the training dynamic even better, and can reveal potential issues.\n", "Let's load the TensorBoard below:"]}, {"cell_type": "code", "execution_count": 13, "id": "1ba8f1ab", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T19:23:14.797808Z", "iopub.status.busy": "2025-04-03T19:23:14.797601Z", "iopub.status.idle": "2025-04-03T19:23:15.820450Z", "shell.execute_reply": "2025-04-03T19:23:15.819513Z"}, "papermill": {"duration": 1.035353, "end_time": "2025-04-03T19:23:15.821694", "exception": false, "start_time": "2025-04-03T19:23:14.786341", "status": "completed"}, "tags": []}, "outputs": [{"data": {"text/html": ["\n", " \n", " \n", " "], "text/plain": [""]}, "metadata": {}, "output_type": "display_data"}], "source": ["# The following two lines open TensorBoard in the notebook.\n", "# Adjust the path to your CHECKPOINT_PATH if needed.\n", "%load_ext tensorboard\n", "%tensorboard --logdir 
../saved_models/tutorial8/tensorboards/"]}, {"cell_type": "markdown", "id": "34794d15", "metadata": {"papermill": {"duration": 0.007791, "end_time": "2025-04-03T19:23:15.837464", "exception": false, "start_time": "2025-04-03T19:23:15.829673", "status": "completed"}, "tags": []}, "source": [""]}, {"cell_type": "markdown", "id": "9d9df084", "metadata": {"papermill": {"duration": 0.007748, "end_time": "2025-04-03T19:23:15.853196", "exception": false, "start_time": "2025-04-03T19:23:15.845448", "status": "completed"}, "tags": []}, "source": ["We see that the contrastive divergence as well as the regularization converge quickly to 0.\n", "However, the training continues although the loss is always close to zero.\n", "This is because our \"training\" data changes with the model by sampling.\n", "The progress of training is best measured by looking at the samples across iterations,\n", "and at the score for random images, which decreases steadily over time."]}, {"cell_type": "markdown", "id": "20e3fd9e", "metadata": {"papermill": {"duration": 0.007741, "end_time": "2025-04-03T19:23:15.868686", "exception": false, "start_time": "2025-04-03T19:23:15.860945", "status": "completed"}, "tags": []}, "source": ["### Image Generation\n", "\n", "Another way of evaluating generative models is by sampling a few generated images.\n", "Generative models need to be good at generating realistic images, as this shows that they have modeled the true data distribution.\n", "Thus, let's sample a few images from the model below:"]}, {"cell_type": "code", "execution_count": 14, "id": "0f3a410d", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T19:23:15.886253Z", "iopub.status.busy": "2025-04-03T19:23:15.885377Z", "iopub.status.idle": "2025-04-03T19:23:16.785122Z", "shell.execute_reply": "2025-04-03T19:23:16.784107Z"}, "papermill": {"duration": 0.910721, "end_time": "2025-04-03T19:23:16.787132", "exception": false, "start_time": "2025-04-03T19:23:15.876411", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["Seed set to 43\n"]}], "source": ["model.to(device)\n", "pl.seed_everything(43)\n", "callback = GenerateCallback(batch_size=4, vis_steps=8, num_steps=256)\n", "imgs_per_step = callback.generate_imgs(model)\n", 
"imgs_per_step = imgs_per_step.cpu()"]}, {"cell_type": "markdown", "id": "bc9e6eab", "metadata": {"papermill": {"duration": 0.00796, "end_time": "2025-04-03T19:23:16.803203", "exception": false, "start_time": "2025-04-03T19:23:16.795243", "status": "completed"}, "tags": []}, "source": ["A characteristic of sampling with energy-based models is that it requires an iterative MCMC algorithm.\n", "To gain insight into how the images change over iterations, we plot a few intermediate samples of the MCMC as well:"]}, {"cell_type": "code", "execution_count": 15, "id": "27d6130f", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T19:23:16.820471Z", "iopub.status.busy": "2025-04-03T19:23:16.820081Z", "iopub.status.idle": "2025-04-03T19:23:17.338777Z", "shell.execute_reply": "2025-04-03T19:23:17.337715Z"}, "papermill": {"duration": 0.529158, "end_time": "2025-04-03T19:23:17.340172", "exception": false, "start_time": "2025-04-03T19:23:16.811014", "status": "completed"}, "tags": []}, "outputs": [{"data": {"application/pdf": "", "image/svg+xml": ["
\n"], "text/plain": [""]}, "metadata": {}, "output_type": "display_data"}], "source": ["for i in range(imgs_per_step.shape[1]):\n", "    step_size = callback.num_steps // callback.vis_steps\n", "    imgs_to_plot = imgs_per_step[step_size - 1 :: step_size, i]\n", "    imgs_to_plot = torch.cat([imgs_per_step[0:1, i], imgs_to_plot], dim=0)\n", "    grid = torchvision.utils.make_grid(\n", "        imgs_to_plot, nrow=imgs_to_plot.shape[0], normalize=True, value_range=(-1, 1), pad_value=0.5, padding=2\n", "    )\n", "    grid = grid.permute(1, 2, 0)\n", "    plt.figure(figsize=(8, 8))\n", "    plt.imshow(grid)\n", "    plt.xlabel(\"Generation iteration\")\n", "    plt.xticks(\n", "        [(imgs_per_step.shape[-1] + 2) * (0.5 + j) for j in range(callback.vis_steps + 1)],\n", "        labels=[1] + list(range(step_size, imgs_per_step.shape[0] + 1, step_size)),\n", "    )\n", "    plt.yticks([])\n", "    plt.show()"]}, {"cell_type": "markdown", "id": "fe7e7d64", "metadata": {"papermill": {"duration": 0.014428, "end_time": "2025-04-03T19:23:17.370520", "exception": false, "start_time": "2025-04-03T19:23:17.356092", "status": "completed"}, "tags": []}, "source": ["We see that although starting from noise in the very first step, the sampling algorithm obtains reasonable shapes after only 32 steps.\n", "Over the next 200 steps, the shapes become clearer and change towards realistic digits.\n", "The specific samples can differ when you run the code on Colab, hence the following description is specific to the plots shown on the website.\n", "The first row shows an 8, where we remove unnecessary white parts over iterations.\n", "The transformation across iterations can be seen best for the second sample, which creates the digit 2.\n", "While the sample after 32 iterations only vaguely resembles a digit,\n", "it is transformed more and more into a typical image of the digit 2."]}, {"cell_type": "markdown", "id": "4b33d4b3", "metadata": {"papermill": {"duration": 0.014423, "end_time": "2025-04-03T19:23:17.398884", "exception": false, 
"start_time": "2025-04-03T19:23:17.384461", "status": "completed"}, "tags": []}, "source": ["### Out-of-distribution detection\n", "\n", "A very common and strong application of energy-based models is out-of-distribution detection\n", "(sometimes referred to as \"anomaly\" detection).\n", "As more and more deep learning models are applied in production and applications,\n", "a crucial aspect of these models is to know what the models don't know.\n", "Deep learning models are usually overconfident, meaning that they sometimes classify even random images with 100% probability.\n", "Clearly, this is not something that we want to see in applications.\n", "Energy-based models can help with this problem because they are trained to detect images that do not fit the training dataset distribution.\n", "Thus, in those applications, you could train an energy-based model along with the classifier,\n", "and only output predictions if the energy-based model assigns an (unnormalized) probability higher than a threshold $\\delta$ to the image.\n", "You can actually combine classifiers and energy-based objectives in a single model,\n", "as proposed in this [paper](https://arxiv.org/abs/1912.03263).\n", "\n", "In this part of the analysis, we want to test the out-of-distribution capability of our energy-based model.\n", "Remember that a lower output of the model denotes a lower probability.\n", "Thus, we hope to see low scores if we feed random noise into the model:"]}, {"cell_type": "code", "execution_count": 16, "id": "fce8c186", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T19:23:17.428680Z", "iopub.status.busy": "2025-04-03T19:23:17.428464Z", "iopub.status.idle": "2025-04-03T19:23:17.491413Z", "shell.execute_reply": "2025-04-03T19:23:17.490431Z"}, "papermill": {"duration": 0.079868, "end_time": "2025-04-03T19:23:17.492683", "exception": false, "start_time": "2025-04-03T19:23:17.412815", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", 
"text": ["Average score for random images: -17.878559112548828\n"]}], "source": ["with torch.no_grad():\n", "    rand_imgs = torch.rand((128,) + model.hparams.img_shape).to(model.device)\n", "    rand_imgs = rand_imgs * 2 - 1.0\n", "    rand_out = model.cnn(rand_imgs).mean()\n", "    print(f\"Average score for random images: {rand_out.item()}\")"]}, {"cell_type": "markdown", "id": "5990fa0f", "metadata": {"papermill": {"duration": 0.014063, "end_time": "2025-04-03T19:23:17.520975", "exception": false, "start_time": "2025-04-03T19:23:17.506912", "status": "completed"}, "tags": []}, "source": ["As we hoped, the model assigns very low probability to those noisy images.\n", "As another reference, let's look at predictions for a batch of images from the training set:"]}, {"cell_type": "code", "execution_count": 17, "id": "e7779be6", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T19:23:17.550127Z", "iopub.status.busy": "2025-04-03T19:23:17.549923Z", "iopub.status.idle": "2025-04-03T19:23:17.903348Z", "shell.execute_reply": "2025-04-03T19:23:17.902213Z"}, "papermill": {"duration": 0.370192, "end_time": "2025-04-03T19:23:17.905145", "exception": false, "start_time": "2025-04-03T19:23:17.534953", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["Average score for training images: -0.01\n"]}], "source": ["with torch.no_grad():\n", "    train_imgs, _ = next(iter(train_loader))\n", "    train_imgs = train_imgs.to(model.device)\n", "    train_out = model.cnn(train_imgs).mean()\n", "    print(f\"Average score for training images: {train_out.item():4.2f}\")"]}, {"cell_type": "markdown", "id": "42ead070", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.020048, "end_time": "2025-04-03T19:23:17.945686", "exception": false, "start_time": "2025-04-03T19:23:17.925638", "status": "completed"}, "tags": []}, "source": ["The scores are close to 0 because of the regularization objective that was added to the training.\n", "So 
clearly, the model can distinguish between noise and real digits.\n", "However, what happens if we change the images a little, and which changes lead to a very low score?"]}, {"cell_type": "code", "execution_count": 18, "id": "5a4e8a71", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T19:23:17.981955Z", "iopub.status.busy": "2025-04-03T19:23:17.981542Z", "iopub.status.idle": "2025-04-03T19:23:17.992024Z", "shell.execute_reply": "2025-04-03T19:23:17.991070Z"}, "papermill": {"duration": 0.028035, "end_time": "2025-04-03T19:23:17.993346", "exception": false, "start_time": "2025-04-03T19:23:17.965311", "status": "completed"}, "tags": []}, "outputs": [], "source": ["@torch.no_grad()\n", "def compare_images(img1, img2):\n", "    # Score both images in a single forward pass\n", "    imgs = torch.stack([img1, img2], dim=0).to(model.device)\n", "    score1, score2 = model.cnn(imgs).cpu().chunk(2, dim=0)\n", "    # Plot the two images side by side\n", "    grid = torchvision.utils.make_grid(\n", "        [img1.cpu(), img2.cpu()], nrow=2, normalize=True, value_range=(-1, 1), pad_value=0.5, padding=2\n", "    )\n", "    grid = grid.permute(1, 2, 0)\n", "    plt.figure(figsize=(4, 4))\n", "    plt.imshow(grid)\n", "    plt.xticks([(img1.shape[2] + 2) * (0.5 + j) for j in range(2)], labels=[\"Original image\", \"Transformed image\"])\n", "    plt.yticks([])\n", "    plt.show()\n", "    print(f\"Score original image: {score1}\")\n", "    print(f\"Score transformed image: {score2}\")"]}, {"cell_type": "markdown", "id": "3b544f22", "metadata": {"papermill": {"duration": 0.014267, "end_time": "2025-04-03T19:23:18.021845", "exception": false, "start_time": "2025-04-03T19:23:18.007578", "status": "completed"}, "tags": []}, "source": ["We use a random test image for this. 
Feel free to change it to experiment with the model yourself."]}, {"cell_type": "code", "execution_count": 19, "id": "4682c38f", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T19:23:18.052015Z", "iopub.status.busy": "2025-04-03T19:23:18.051650Z", "iopub.status.idle": "2025-04-03T19:23:18.422281Z", "shell.execute_reply": "2025-04-03T19:23:18.420665Z"}, "papermill": {"duration": 0.388112, "end_time": "2025-04-03T19:23:18.424359", "exception": false, "start_time": "2025-04-03T19:23:18.036247", "status": "completed"}, "tags": []}, "outputs": [], "source": ["test_imgs, _ = next(iter(test_loader))\n", "exmp_img = test_imgs[0].to(model.device)"]}, {"cell_type": "markdown", "id": "ca27f72b", "metadata": {"papermill": {"duration": 0.01431, "end_time": "2025-04-03T19:23:18.453377", "exception": false, "start_time": "2025-04-03T19:23:18.439067", "status": "completed"}, "tags": []}, "source": ["The first transformation is to add some random noise to the image:"]}, {"cell_type": "code", "execution_count": 20, "id": "0e973c65", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T19:23:18.484315Z", "iopub.status.busy": "2025-04-03T19:23:18.483899Z", "iopub.status.idle": "2025-04-03T19:23:18.595407Z", "shell.execute_reply": "2025-04-03T19:23:18.594325Z"}, "papermill": {"duration": 0.129024, "end_time": "2025-04-03T19:23:18.597017", "exception": false, "start_time": "2025-04-03T19:23:18.467993", "status": "completed"}, "tags": []}, "outputs": [{"data": {"application/pdf": "", "image/svg+xml": ["\n"], "text/plain": [""]}, "metadata": {}, "output_type": "display_data"}, {"name": "stdout", "output_type": "stream", "text": ["Score original image: tensor([0.0304])\n", "Score transformed image: tensor([-0.0746])\n"]}], "source": ["# Add Gaussian noise (std 0.3) and clamp back to the valid image range [-1, 1]\n", "img_noisy = exmp_img + torch.randn_like(exmp_img) * 0.3\n", "img_noisy.clamp_(min=-1.0, max=1.0)\n", "compare_images(exmp_img, img_noisy)"]}, {"cell_type": "markdown", "id": "d06b5110", "metadata": {"papermill": {"duration": 0.020627, "end_time": "2025-04-03T19:23:18.638559", "exception": false, "start_time": "2025-04-03T19:23:18.617932", "status": "completed"}, "tags": []}, "source": ["We can see that the score drops considerably.\n", "Hence, the model can detect random Gaussian noise on the image.\n", "This is to be expected because, at the start of training, the \"fake\" samples are pure noise images.\n", "\n", "Next, we flip the image along both spatial axes (equivalent to a 180-degree rotation) and check how this influences the score:"]}, {"cell_type": "code", "execution_count": 21, "id": "88f1028e", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T19:23:18.674580Z", "iopub.status.busy": "2025-04-03T19:23:18.674315Z", "iopub.status.idle": "2025-04-03T19:23:18.771248Z", "shell.execute_reply": "2025-04-03T19:23:18.770121Z"}, "papermill": {"duration": 0.114667, "end_time": "2025-04-03T19:23:18.772878", "exception": false, "start_time": "2025-04-03T19:23:18.658211", "status": "completed"}, "tags": []}, "outputs": [{"data": {"application/pdf": "", "image/svg+xml": ["
\n"], "text/plain": [""]}, "metadata": {}, "output_type": "display_data"}, {"name": "stdout", "output_type": "stream", "text": ["Score original image: tensor([0.0304])\n", "Score transformed image: tensor([-0.0004])\n"]}], "source": ["img_flipped = exmp_img.flip(dims=(1, 2))\n", "compare_images(exmp_img, img_flipped)"]}, {"cell_type": "markdown", "id": "19958fcc", "metadata": {"papermill": {"duration": 0.021629, "end_time": "2025-04-03T19:23:18.817148", "exception": false, "start_time": "2025-04-03T19:23:18.795519", "status": "completed"}, "tags": []}, "source": ["If the digit is only legible in its original orientation, for example, the 7, then we can see that the score drops.\n", "However, the score drops only slightly.\n", "This is likely because of the small size of our model.\n", "Keep in mind that generative modeling is a much harder task than classification,\n", "as we need not only to distinguish between classes but also to learn **all** details/characteristics of the digits.\n", "With a deeper model, this could eventually be captured better (but at the cost of greater training instability).\n", "\n", "Finally, we check what happens if we reduce the digit significantly in size:"]}, {"cell_type": "code", "execution_count": 22, "id": "1f976db5", "metadata": {"execution": {"iopub.execute_input": "2025-04-03T19:23:18.853184Z", "iopub.status.busy": "2025-04-03T19:23:18.852980Z", "iopub.status.idle": "2025-04-03T19:23:18.929501Z", "shell.execute_reply": "2025-04-03T19:23:18.928554Z"}, "papermill": {"duration": 0.094483, "end_time": "2025-04-03T19:23:18.931027", "exception": false, "start_time": "2025-04-03T19:23:18.836544", "status": "completed"}, "tags": []}, "outputs": [{"data": {"application/pdf": "", "image/svg+xml": ["\n"], "text/plain": [""]}, "metadata": {}, "output_type": "display_data"}, {"name": "stdout", "output_type": "stream", "text": ["Score original image: tensor([0.0304])\n", "Score transformed image: tensor([-0.0154])\n"]}], "source": ["# Start from an all-background image (-1) and paste a 2x-downsampled digit into the bottom-right quadrant\n", "img_tiny = torch.zeros_like(exmp_img) - 1\n", "img_tiny[:, exmp_img.shape[1] // 2 :, exmp_img.shape[2] // 2 :] = exmp_img[:, ::2, ::2]\n", "compare_images(exmp_img, img_tiny)"]}, {"cell_type": "markdown", "id": "9e36062c", "metadata": {"papermill": {"duration": 0.022054, "end_time": "2025-04-03T19:23:18.975198", "exception": false, "start_time": "2025-04-03T19:23:18.953144", "status": "completed"}, "tags": []}, "source": ["The score drops again, but not by a large margin, even though digits in the MNIST dataset are usually much larger.\n", "\n", "Overall, we can conclude that our model is good at detecting Gaussian noise and smaller transformations of existing digits.\n", "Nonetheless, to obtain a very good out-of-distribution model, we would need to train deeper models for more iterations."]}, {"cell_type": "markdown", "id": "44c03299", "metadata": {"papermill": {"duration": 0.016248, "end_time": "2025-04-03T19:23:19.009791", "exception": false, "start_time": "2025-04-03T19:23:18.993543", "status": "completed"}, "tags": []}, "source": ["### Instability\n", "\n", "Finally, we should discuss the possible instabilities of energy-based models,\n", "in particular for the example of image generation that we have implemented in this notebook.\n", "During the hyperparameter search for this notebook, several models diverged.\n", "Divergence in energy-based models means that the models assign a high probability to examples of the training set, which in itself is a good thing.\n", "However, at the same time, the sampling algorithm fails and only generates noise images that obtain minimal probability scores.\n", "This happens because the model has created many local maxima in which the generated noise images fall.\n", "The energy surface over which 
we calculate the gradients to reach data points with high probability has \"diverged\" and is not useful for our MCMC sampling.\n", "\n", "Besides finding the optimal hyperparameters, a common trick in energy-based models is to reload stable checkpoints.\n", "If we detect that the model is diverging, we stop the training and load the model from one epoch earlier, when it had not yet diverged.\n", "Afterward, we continue training and hope that with a different seed the model does not diverge again.\n", "Nevertheless, this should be considered the \"last hope\" for stabilizing the models,\n", "and careful hyperparameter tuning is the better way to do so.\n", "Sensitive hyperparameters include `step_size`, `steps` and the noise standard deviation in the sampler,\n", "and the learning rate and feature dimensionality in the CNN model."]}, {"cell_type": "markdown", "id": "720e54dc", "metadata": {"papermill": {"duration": 0.016346, "end_time": "2025-04-03T19:23:19.042366", "exception": false, "start_time": "2025-04-03T19:23:19.026020", "status": "completed"}, "tags": []}, "source": ["## Conclusion\n", "\n", "In this tutorial, we have discussed energy-based models for generative modeling.\n", "The concept relies on the idea that any strictly positive function can be turned into a probability\n", "distribution by normalizing over the whole input space.\n", "As this normalization is intractable for high-dimensional data like images,\n", "we train the model using contrastive divergence and sampling via MCMC.\n", "While the idea allows us to turn any neural network into an energy-based model,\n", "we have seen that multiple training tricks are needed to stabilize the training.\n", "Furthermore, the training time of these models is relatively long as, during every training iteration,\n", "we need to sample new \"fake\" images, even with a sampling buffer.\n", "In the next lectures and assignment, we will see different generative models (e.g. 
VAE, GAN, NF)\n", "that allow us to do generative modeling more stably, but at the cost of more parameters."]}, {"cell_type": "markdown", "id": "27a98b28", "metadata": {"papermill": {"duration": 0.01609, "end_time": "2025-04-03T19:23:19.074551", "exception": false, "start_time": "2025-04-03T19:23:19.058461", "status": "completed"}, "tags": []}, "source": ["## Congratulations - Time to Join the Community!\n", "\n", "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning\n", "movement, you can do so in the following ways!\n", "\n", "### Star [Lightning](https://github.com/Lightning-AI/lightning) on GitHub\n", "The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool\n", "tools we're building.\n", "\n", "### Join our [Discord](https://discord.com/invite/tfXFetEZxv)!\n", "The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself\n", "and share your interests in the `#general` channel.\n", "\n", "\n", "### Contributions !\n", "The best way to contribute to our community is to become a code contributor! 
At any time you can go to\n", "[Lightning](https://github.com/Lightning-AI/lightning) or [Bolt](https://github.com/Lightning-AI/lightning-bolts)\n", "GitHub Issues page and filter for \"good first issue\".\n", "\n", "* [Lightning good first issue](https://github.com/Lightning-AI/lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", "* [Bolt good first issue](https://github.com/Lightning-AI/lightning-bolts/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n", "* You can also contribute your own notebooks with useful examples !\n", "\n", "### Great thanks from the entire Pytorch Lightning Team for your interest !\n", "\n", "[![Pytorch Lightning](){height=\"60px\" width=\"240px\"}](https://pytorchlightning.ai)"]}, {"cell_type": "raw", "metadata": {"raw_mimetype": "text/restructuredtext"}, "source": [".. customcarditem::\n", " :header: Tutorial 7: Deep Energy-Based Generative Models\n", " :card_description: In this tutorial, we will look at energy-based deep learning models, and focus on their application as generative models. 
Energy models have been a popular tool before the...\n", " :tags: Image,GPU/TPU,UvA-DL-Course\n", " :image: _static/images/course_UvA-DL/07-deep-energy-based-generative-models.jpg"]}], "metadata": {"jupytext": {"cell_metadata_filter": "colab,colab_type,id,-all", "formats": "ipynb,py:percent", "main_language": "python"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12"}, "papermill": {"default_parameters": {}, "duration": 16.797735, "end_time": "2025-04-03T19:23:22.011023", "environment_variables": {}, "exception": null, "input_path": "course_UvA-DL/07-deep-energy-based-generative-models/notebook.ipynb", "output_path": ".notebooks/course_UvA-DL/07-deep-energy-based-generative-models.ipynb", "parameters": {}, "start_time": "2025-04-03T19:23:05.213288", "version": "2.6.0"}}, "nbformat": 4, "nbformat_minor": 5}