{"cells": [{"cell_type": "markdown", "id": "d65ee16b", "metadata": {"papermill": {"duration": 0.0032, "end_time": "2025-05-01T12:07:39.423944", "exception": false, "start_time": "2025-05-01T12:07:39.420744", "status": "completed"}, "tags": []}, "source": ["\n", "# Finetune Transformers Models with PyTorch Lightning\n", "\n", "* **Author:** Lightning.ai\n", "* **License:** CC BY-SA\n", "* **Generated:** 2025-05-01T12:07:32.322089\n", "\n", "This notebook will use HuggingFace's `datasets` library to get data, which will be wrapped in a `LightningDataModule`.\n", "Then, we write a class to perform text classification on any dataset from the [GLUE Benchmark](https://gluebenchmark.com/).\n", "(We just show CoLA and MRPC due to constraint on compute/disk)\n", "\n", "\n", "---\n", "Open in [![Open In Colab](data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAHUAAAAUCAYAAACzrHJDAAAIuUlEQVRoQ+1ZaVRURxb+qhdolmbTUVSURpZgmLhHbQVFZIlGQBEXcMvJhKiTEzfigjQg7oNEJ9GMGidnjnNMBs2czIzajksEFRE1xklCTKJiQLRFsUGkoUWw+82pamn79etGYoKek1B/4NW99/tu3e/dquJBAGD27NkHALxKf39WY39gyrOi+i3xqGtUoePJrFmznrmgtModorbTu8YRNZk5cybXTvCtwh7o6NR2KzuZMWNGh6jtVt7nA0ymT5/eJlF9POrh7PAQl6s8bGYa3PUum//htmebVtLRqW0q01M5keTk5FZFzU0oRle3+zxwg5Hgtb+PZiL/ZVohxCI+hL5JgjmfjPxZ26+33BG3dA+ealHPM4gQAo5rU59gsI8bRvl54t3Ca62mvHyUAhtOlLd5WSQpKcluBjumnoCLs1EARkVd9E8l3p9y2i7RbQ1B6pFwu/YDgW8KbHJHMTQrwnjz2oZm9M4pavOCfo5jWrgCaaMVcMs6/pNhDr0+AMN93XlxV7R6DNpyzi7W/OE+yIrsjU6rTrbKV5cd/pNyItOmTbMp6sbBB+EqaYJY4cWE3VUciNt1TpgfcRFv71Fi54xT5kSoyLvOBEJMOMxWXkFlBeBSX4u6Zkcs+3KszYRtiapbNRqF31UgetVuc8z9vBXIv1qD+F1f83B6uDlCUyfsZGepGPpmg01OB7EITQbhS9ribKy+DmP1DUiClLz4bnIHVOqa7BY+Z1wg5g3zgUvyehiNpnJKxSLc/ts76LKm0BzX3c0RNy1yXjDcB5lWoro4iNHQxM+f1kWeWQARAWQS++trISJTp061Kep25X/MycwtjuctSC5rxo7ppi7VNUox5+PhPHtrsS2O1qJ6yx1QujQUzm9sh6hbkBlvvGcN8hYnwjUjH6kjfZEd5c/jitz5Jc5U3ENnFynKl4eB7nyEgP2UZ+Yz3/rVEbyYr27qELrtC4FIC0J7sc7xWnmccdHfRRTs0VB+cA4lt+oFcRR/wUeH8FG5w2Mbx8FQ8TXEvv1xYf4wBP3O2WyL3/UVjpXWgIqaFeUPr+wTmDvUB7njH6/bOv+HRg4SqioAg5GDe1aB3ZeMTJkyRSBqkLsWqSEm0fZVBE
N94zEZnYvrdx1JL5cxe+a+AbhSJecRRHW/ikTFRTa38dtQlNZ5CRKwFvUtZU/kvBoEF9Uxni/XqIM+dwKbTw3rhcxIf7gmr2M+H6SMwx8iBzJbw5oxeG3Lv5FX9B3AGaHPS8e8z77H7v9VMpvPG5ug1enh7eGK8h0LBTwUb+GInqzInlRUK65DmTPQu4c3+uQKjwKK77zwUxBX4Tq7yR1RuiwUsqlrABCM6esHdXoy47fk4+prYKy8ZF574x4V5BnHQBuf4g9Z9ld8U36L2aktZNNplNfw7zotwWTy5MkCUft4aLEopJj5/OPHl1BQqeAVOnHgNSQOqmBzq9V9cfEm/yx5ubMGKS9cYPZ3vx2OS/c6PVHUuUO7Y1Pci3BO/1zgq18byebfGemLtNF+6JRtOvMk926ibussZqM+1mNz4TWkH7rCbM5phwGRGDAaoF8fY5OHFnlldAA8sgoEXKnDukA1NgSeNjqkJT9brbN4pC9WRweYXyLugR73c+MYvyWfu0yC6+mjzN1Isfw3FKJS98CU/zI1IHFkFPR52cHL2FJk0sB6kMTERIGo9GzcPkLNfA0cwdwi/hfEYO86ZMd9w+y1egfM2T2Eh/vesMNwljSzuZRT420SW3eqy8N6aHMmwmnFUZ7/PGVPbIoNZvNU1BURdHs0bT2+HjL8sDSM2e6vi4Lj5NW8WOLVA6RTT2azxLV+bglaFNqLieqemS/gWkw7NyoAHo+2dEsiivengjKsPFoqWOvbSh/kxPaxyW/JRzH2Fl3EzD9/xjAefJqB3usKUFn/0Gb+S/d/jy3FN2yLOmnSJJtn6oehByEiHPSeXnDxFGPRnoFoaBJjcdQlbDwcjL1zTNuQpoxD7R0OG0uUTMi0fkVwdzBdYIwcwZunxrVJVLplNm54BZp7jfDfYLoNyqQi1K6KxIdHzmN+QQ2WjFIwUT2zTGdlRXo4NFXVUO4sgX5dFC7f0aP/ZlNeUjFBuL8Xjl6uRuP6aMjSjpjzsH62FDU7JhBuGccEXIvDfJFFBc/gHw80dklfCVYnRaDfpiJcutPA4F7qJsfJeUPQI+1fqMlNhFx1FM0GDqkjFVg7NojlQ0Vt4aM5ReSqcbpaCg8nCW5lRsBvbT4T1TLfFptsfh7gItzuKTdJSEiwKSrt1vcmnEXXrsLbYnWDA1bu+z2WKy9Arq+1KRqdfKsoBo0GcdtEpS/B1bO4v0cFiUhkjskvKcMrWwtAPHuwQq8Z+4LZ1vTQANfXt4J0DwZX9gWa9qh4XDM/voC9JXfwYEMMHJcfNtusn82ihvliVUwg5KrPGVf6GH94ZJpEZBen6EC4qYTHA1dXhW0JIex8txzv//c8lhzXIi/BFxOH9jGbQhZsRalTIBZZ8KkGyZAxeRQvXkFF1TWz/Hm46jNYUnjPbt3JxIkT7f6dSj8qfJJyVvBxgaIlblOyjtysNHWN9fjjqWi7glJfW3/S0Hlj2XnA8PhKT9w6g3Qx3XiXhvuxQsuT1proxBKI/AaZqY1Xz5muvY8G8XkRRCaHsfQsRAFDH/tZPbcYuHotOG0FRIqB4HR3wNVoIPLtz8ycTguu+jpEigE218vd1YCr5m+HpHMvEI9u4LTXwNWaLjl0iPwGAmIpeHx1VeCqTJdPs1/vweweQPO3HC24NhOhnTphwoQnfv6QSY2ICbkNmdSA4h87oaLaiYfn5diIEd4att2erOwJXbPUHp953p6orQVSUVWRAXBT8c/dJ5L9xhzaJGp71GR/wFP8P5V2z10NSC9T93QM2xUg8fHxT+zU9ijeU4naHon8CjFJXFzc8/kn+dN06q9QgF98SYSo2Xen2NjYZy5sR6f+4nLSK5Iam2PH/x87a1YN/t5sBgAAAABJRU5ErkJggg==){height=\"20px\" 
width=\"117px\"}](https://colab.research.google.com/github/PytorchLightning/lightning-tutorials/blob/publication/.notebooks/lightning_examples/text-transformers.ipynb)\n", "\n", "Give us a \u2b50 [on Github](https://www.github.com/Lightning-AI/lightning/)\n", "| Check out [the documentation](https://lightning.ai/docs/)\n", "| Join us [on Discord](https://discord.com/invite/tfXFetEZxv)"]}, {"cell_type": "markdown", "id": "28a6e9a5", "metadata": {"papermill": {"duration": 0.002272, "end_time": "2025-05-01T12:07:39.428830", "exception": false, "start_time": "2025-05-01T12:07:39.426558", "status": "completed"}, "tags": []}, "source": ["## Setup\n", "This notebook requires some packages besides pytorch-lightning."]}, {"cell_type": "code", "execution_count": 1, "id": "873da06b", "metadata": {"colab": {}, "colab_type": "code", "execution": {"iopub.execute_input": "2025-05-01T12:07:39.434712Z", "iopub.status.busy": "2025-05-01T12:07:39.434482Z", "iopub.status.idle": "2025-05-01T12:07:40.867143Z", "shell.execute_reply": "2025-05-01T12:07:40.866103Z"}, "id": "LfrJLKPFyhsK", "lines_to_next_cell": 0, "papermill": {"duration": 1.437184, "end_time": "2025-05-01T12:07:40.868569", "exception": false, "start_time": "2025-05-01T12:07:39.431385", "status": "completed"}, "tags": []}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\r\n", "\u001b[0m"]}], "source": ["! 
pip install --quiet \"torchtext\" \"numpy <2.0\" \"datasets <=2.21\" \"torch ==2.1.*\" \"scikit-learn\" \"scipy\" \"torchvision ==0.16.*\" \"pytorch-lightning >=2.0,<2.6\" \"matplotlib\" \"torchmetrics >=1.0, <1.2\" \"transformers <4.50\""]}, {"cell_type": "code", "execution_count": 2, "id": "81c50b79", "metadata": {"execution": {"iopub.execute_input": "2025-05-01T12:07:40.875296Z", "iopub.status.busy": "2025-05-01T12:07:40.874931Z", "iopub.status.idle": "2025-05-01T12:07:45.422107Z", "shell.execute_reply": "2025-05-01T12:07:45.420847Z"}, "papermill": {"duration": 4.552094, "end_time": "2025-05-01T12:07:45.423562", "exception": false, "start_time": "2025-05-01T12:07:40.871468", "status": "completed"}, "tags": []}, "outputs": [], "source": ["from collections import defaultdict\n", "from datetime import datetime\n", "from typing import Optional\n", "\n", "import datasets\n", "import pytorch_lightning as pl\n", "import torch\n", "from torch.optim import AdamW\n", "from torch.utils.data import DataLoader\n", "from transformers import (\n", " AutoConfig,\n", " AutoModelForSequenceClassification,\n", " AutoTokenizer,\n", " get_linear_schedule_with_warmup,\n", ")"]}, {"cell_type": "markdown", "id": "412c20ff", "metadata": {"papermill": {"duration": 0.002458, "end_time": "2025-05-01T12:07:45.428681", "exception": false, "start_time": "2025-05-01T12:07:45.426223", "status": "completed"}, "tags": []}, "source": ["## Training BERT with Lightning"]}, {"cell_type": "markdown", "id": "832fbc97", "metadata": {"lines_to_next_cell": 2, "papermill": {"duration": 0.002401, "end_time": "2025-05-01T12:07:45.434145", "exception": false, "start_time": "2025-05-01T12:07:45.431744", "status": "completed"}, "tags": []}, "source": ["### Lightning DataModule for GLUE"]}, {"cell_type": "code", "execution_count": 3, "id": "1ed613d4", "metadata": {"execution": {"iopub.execute_input": "2025-05-01T12:07:45.440330Z", "iopub.status.busy": 
"2025-05-01T12:07:45.439921Z", "iopub.status.idle": "2025-05-01T12:07:45.452511Z", "shell.execute_reply": "2025-05-01T12:07:45.451530Z"}, "papermill": {"duration": 0.017025, "end_time": "2025-05-01T12:07:45.453554", "exception": false, "start_time": "2025-05-01T12:07:45.436529", "status": "completed"}, "tags": []}, "outputs": [], "source": ["class GLUEDataModule(pl.LightningDataModule):\n", " task_text_field_map = {\n", " \"cola\": [\"sentence\"],\n", " \"sst2\": [\"sentence\"],\n", " \"mrpc\": [\"sentence1\", \"sentence2\"],\n", " \"qqp\": [\"question1\", \"question2\"],\n", " \"stsb\": [\"sentence1\", \"sentence2\"],\n", " \"mnli\": [\"premise\", \"hypothesis\"],\n", " \"qnli\": [\"question\", \"sentence\"],\n", " \"rte\": [\"sentence1\", \"sentence2\"],\n", " \"wnli\": [\"sentence1\", \"sentence2\"],\n", " \"ax\": [\"premise\", \"hypothesis\"],\n", " }\n", "\n", " glue_task_num_labels = {\n", " \"cola\": 2,\n", " \"sst2\": 2,\n", " \"mrpc\": 2,\n", " \"qqp\": 2,\n", " \"stsb\": 1,\n", " \"mnli\": 3,\n", " \"qnli\": 2,\n", " \"rte\": 2,\n", " \"wnli\": 2,\n", " \"ax\": 3,\n", " }\n", "\n", " loader_columns = [\n", " \"datasets_idx\",\n", " \"input_ids\",\n", " \"token_type_ids\",\n", " \"attention_mask\",\n", " \"start_positions\",\n", " \"end_positions\",\n", " \"labels\",\n", " ]\n", "\n", " def __init__(\n", " self,\n", " model_name_or_path: str,\n", " task_name: str = \"mrpc\",\n", " max_seq_length: int = 128,\n", " train_batch_size: int = 32,\n", " eval_batch_size: int = 32,\n", " **kwargs,\n", " ):\n", " super().__init__()\n", " self.model_name_or_path = model_name_or_path\n", " self.task_name = task_name\n", " self.max_seq_length = max_seq_length\n", " self.train_batch_size = train_batch_size\n", " self.eval_batch_size = eval_batch_size\n", "\n", " self.text_fields = self.task_text_field_map[task_name]\n", " self.num_labels = self.glue_task_num_labels[task_name]\n", " self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, 
use_fast=True)\n", "\n", " def setup(self, stage=None):\n", " self.dataset = datasets.load_dataset(\"glue\", self.task_name)\n", "\n", " for split in self.dataset.keys():\n", " self.dataset[split] = self.dataset[split].map(\n", " self.convert_to_features,\n", " batched=True,\n", " remove_columns=[\"label\"],\n", " )\n", " self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]\n", " self.dataset[split].set_format(type=\"torch\", columns=self.columns)\n", "\n", " self.eval_splits = [x for x in self.dataset.keys() if \"validation\" in x]\n", "\n", " def prepare_data(self):\n", " datasets.load_dataset(\"glue\", self.task_name)\n", " AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)\n", "\n", " def train_dataloader(self):\n", " return DataLoader(self.dataset[\"train\"], batch_size=self.train_batch_size, shuffle=True)\n", "\n", " def val_dataloader(self):\n", " if len(self.eval_splits) == 1:\n", " return DataLoader(self.dataset[\"validation\"], batch_size=self.eval_batch_size)\n", " elif len(self.eval_splits) > 1:\n", " return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]\n", "\n", " def test_dataloader(self):\n", " if len(self.eval_splits) == 1:\n", " return DataLoader(self.dataset[\"test\"], batch_size=self.eval_batch_size)\n", " elif len(self.eval_splits) > 1:\n", " return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]\n", "\n", " def convert_to_features(self, example_batch, indices=None):\n", " # Either encode single sentence or sentence pairs\n", " if len(self.text_fields) > 1:\n", " texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))\n", " else:\n", " texts_or_text_pairs = example_batch[self.text_fields[0]]\n", "\n", " # Tokenize the text/text pairs\n", " features = self.tokenizer.batch_encode_plus(\n", " texts_or_text_pairs, max_length=self.max_seq_length, 
padding=\"max_length\", truncation=True\n", " )\n", "\n", " # Rename label to labels to make it easier to pass to model forward\n", " features[\"labels\"] = example_batch[\"label\"]\n", "\n", " return features"]}, {"cell_type": "markdown", "id": "fcdcda13", "metadata": {"papermill": {"duration": 0.002665, "end_time": "2025-05-01T12:07:45.482707", "exception": false, "start_time": "2025-05-01T12:07:45.480042", "status": "completed"}, "tags": []}, "source": ["**You could use this datamodule with standalone PyTorch if you wanted...**"]}, {"cell_type": "code", "execution_count": 4, "id": "7239844c", "metadata": {"execution": {"iopub.execute_input": "2025-05-01T12:07:45.488424Z", "iopub.status.busy": "2025-05-01T12:07:45.488249Z", "iopub.status.idle": "2025-05-01T12:07:51.561516Z", "shell.execute_reply": "2025-05-01T12:07:51.560518Z"}, "papermill": {"duration": 6.077954, "end_time": "2025-05-01T12:07:51.563125", "exception": false, "start_time": "2025-05-01T12:07:45.485171", "status": "completed"}, "tags": []}, "outputs": [{"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "d02b8e6793a547c5a9dd5447a9d7f93e", "version_major": 2, "version_minor": 0}, "text/plain": ["tokenizer_config.json: 0%| | 0.00/48.0 [00:00 1:\n", " preds = torch.argmax(logits, dim=1)\n", " elif self.hparams.num_labels == 1:\n", " preds = logits.squeeze()\n", "\n", " labels = batch[\"labels\"]\n", "\n", " self.outputs[dataloader_idx].append({\"loss\": val_loss, \"preds\": preds, \"labels\": labels})\n", "\n", " def on_validation_epoch_end(self):\n", " if self.hparams.task_name == \"mnli\":\n", " for i, outputs in self.outputs.items():\n", " # matched or mismatched\n", " split = self.hparams.eval_splits[i].split(\"_\")[-1]\n", " preds = torch.cat([x[\"preds\"] for x in outputs]).detach().cpu().numpy()\n", " labels = torch.cat([x[\"labels\"] for x in outputs]).detach().cpu().numpy()\n", " loss = torch.stack([x[\"loss\"] for x in outputs]).mean()\n", " self.log(f\"val_loss_{split}\", 
loss, prog_bar=True)\n", " split_metrics = {\n", " f\"{k}_{split}\": v for k, v in self.metric.compute(predictions=preds, references=labels).items()\n", " }\n", " self.log_dict(split_metrics, prog_bar=True)\n", " self.outputs.clear()\n", " return loss\n", "\n", " flat_outputs = []\n", " for lst in self.outputs.values():\n", " flat_outputs.extend(lst)\n", "\n", " preds = torch.cat([x[\"preds\"] for x in flat_outputs]).detach().cpu().numpy()\n", " labels = torch.cat([x[\"labels\"] for x in flat_outputs]).detach().cpu().numpy()\n", " loss = torch.stack([x[\"loss\"] for x in flat_outputs]).mean()\n", " self.log(\"val_loss\", loss, prog_bar=True)\n", " self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)\n", " self.outputs.clear()\n", "\n", " def configure_optimizers(self):\n", " \"\"\"Prepare optimizer and schedule (linear warmup and decay).\"\"\"\n", " model = self.model\n", " no_decay = [\"bias\", \"LayerNorm.weight\"]\n", " optimizer_grouped_parameters = [\n", " {\n", " \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n", " \"weight_decay\": self.hparams.weight_decay,\n", " },\n", " {\n", " \"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],\n", " \"weight_decay\": 0.0,\n", " },\n", " ]\n", " optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.adam_epsilon)\n", "\n", " scheduler = get_linear_schedule_with_warmup(\n", " optimizer,\n", " num_warmup_steps=self.hparams.warmup_steps,\n", " num_training_steps=self.trainer.estimated_stepping_batches,\n", " )\n", " scheduler = {\"scheduler\": scheduler, \"interval\": \"step\", \"frequency\": 1}\n", " return [optimizer], [scheduler]"]}, {"cell_type": "markdown", "id": "80af51d0", "metadata": {"papermill": {"duration": 0.008229, "end_time": "2025-05-01T12:07:51.633055", "exception": false, "start_time": "2025-05-01T12:07:51.624826", "status": "completed"}, "tags": []}, "source": ["## 
Training"]}, {"cell_type": "markdown", "id": "fc128487", "metadata": {"papermill": {"duration": 0.008294, "end_time": "2025-05-01T12:07:51.649960", "exception": false, "start_time": "2025-05-01T12:07:51.641666", "status": "completed"}, "tags": []}, "source": ["### CoLA\n", "\n", "See an interactive view of the\n", "CoLA dataset in [NLP Viewer](https://huggingface.co/nlp/viewer/?dataset=glue&config=cola)"]}, {"cell_type": "code", "execution_count": 6, "id": "84aba97f", "metadata": {"execution": {"iopub.execute_input": "2025-05-01T12:07:51.666513Z", "iopub.status.busy": "2025-05-01T12:07:51.666292Z", "iopub.status.idle": "2025-05-01T12:08:50.297246Z", "shell.execute_reply": "2025-05-01T12:08:50.296287Z"}, "papermill": {"duration": 58.640646, "end_time": "2025-05-01T12:08:50.299139", "exception": false, "start_time": "2025-05-01T12:07:51.658493", "status": "completed"}, "tags": []}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["Seed set to 42\n"]}, {"data": {"application/vnd.jupyter.widget-view+json": {"model_id": "8e43cfcc0d9b4fc6bb1fe31f9c9f9cc4", "version_major": 2, "version_minor": 0}, "text/plain": ["tokenizer_config.json: 0%| | 0.00/25.0 [00:00