Finetune Transformers Models with PyTorch Lightning¶
Author: Lightning.ai
License: CC BY-SA
Generated: 2024-07-26T13:13:57.667341
This notebook uses HuggingFace’s datasets library to get data, which is wrapped in a LightningDataModule. Then, we write a class to perform text classification on any dataset from the GLUE Benchmark. (We just show CoLA and MRPC due to constraints on compute/disk.)
Give us a ⭐ on GitHub | Check out the documentation | Join us on Discord
Setup¶
This notebook requires some packages besides pytorch-lightning.
[1]:
! pip install --quiet "scipy" "datasets" "transformers" "torch>=1.8.1, <2.5" "torchmetrics>=1.0, <1.5" "torchmetrics >=1.0, <1.2" "torch ==2.1.*" "numpy <2.0" "scikit-learn" "torchvision ==0.16.*" "matplotlib" "pytorch-lightning >=2.0,<2.4" "torchtext"
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
[2]:
from collections import defaultdict
from datetime import datetime
from typing import Optional
import datasets
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from transformers import (
AdamW,
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
get_linear_schedule_with_warmup,
)
Training BERT with Lightning¶
Lightning DataModule for GLUE¶
[3]:
class GLUEDataModule(pl.LightningDataModule):
    """Download, tokenize and serve one GLUE task as PyTorch ``DataLoader`` objects.

    The task name selects which text column(s) are tokenized and how many
    labels the task has, using the class-level lookup tables below.
    """

    # Text column(s) passed to the tokenizer for each task. Two entries mean
    # the two columns are encoded together as a text pair.
    task_text_field_map = {
        "cola": ["sentence"],
        "sst2": ["sentence"],
        "mrpc": ["sentence1", "sentence2"],
        "qqp": ["question1", "question2"],
        "stsb": ["sentence1", "sentence2"],
        "mnli": ["premise", "hypothesis"],
        "qnli": ["question", "sentence"],
        "rte": ["sentence1", "sentence2"],
        "wnli": ["sentence1", "sentence2"],
        "ax": ["premise", "hypothesis"],
    }

    # Number of labels per task, exposed as ``self.num_labels``.
    glue_task_num_labels = {
        "cola": 2,
        "sst2": 2,
        "mrpc": 2,
        "qqp": 2,
        "stsb": 1,
        "mnli": 3,
        "qnli": 2,
        "rte": 2,
        "wnli": 2,
        "ax": 3,
    }

    # Columns that are kept (when present after tokenization) and returned
    # as torch tensors by the dataloaders.
    loader_columns = [
        "datasets_idx",
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "start_positions",
        "end_positions",
        "labels",
    ]

    def __init__(
        self,
        model_name_or_path: str,
        task_name: str = "mrpc",
        max_seq_length: int = 128,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        **kwargs,
    ):
        """Store the configuration and load the tokenizer.

        Args:
            model_name_or_path: Identifier passed to ``AutoTokenizer.from_pretrained``.
            task_name: Key into ``task_text_field_map`` / ``glue_task_num_labels``.
            max_seq_length: Length every example is padded or truncated to.
            train_batch_size: Batch size of the training dataloader.
            eval_batch_size: Batch size of the validation and test dataloaders.
            **kwargs: Accepted and ignored.

        Raises:
            KeyError: If ``task_name`` is not one of the known tasks.
        """
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.task_name = task_name
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size

        self.text_fields = self.task_text_field_map[task_name]
        self.num_labels = self.glue_task_num_labels[task_name]
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    def setup(self, stage: Optional[str] = None):
        """Load the task, tokenize every split and expose the result as tensors.

        ``stage`` is accepted for Lightning compatibility but is not used:
        every split of the dataset is always processed.
        """
        self.dataset = datasets.load_dataset("glue", self.task_name)

        for split in self.dataset.keys():
            self.dataset[split] = self.dataset[split].map(
                self.convert_to_features,
                batched=True,
                remove_columns=["label"],
            )
            self.columns = [c for c in self.dataset[split].column_names if c in self.loader_columns]
            self.dataset[split].set_format(type="torch", columns=self.columns)

        # A task may expose several evaluation splits, so collect them by
        # substring instead of assuming a single fixed split name.
        self.eval_splits = [x for x in self.dataset.keys() if "validation" in x]
        self.test_splits = [x for x in self.dataset.keys() if "test" in x]

    def prepare_data(self):
        """Trigger the dataset and tokenizer downloads; the results are discarded."""
        datasets.load_dataset("glue", self.task_name)
        AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    def train_dataloader(self):
        """Return a shuffled dataloader over the ``"train"`` split."""
        return DataLoader(self.dataset["train"], batch_size=self.train_batch_size, shuffle=True)

    def _eval_loaders(self, splits):
        """Return ``None``, a single loader, or a list of loaders for ``splits``."""
        loaders = [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in splits]
        if not loaders:
            return None
        return loaders[0] if len(loaders) == 1 else loaders

    def val_dataloader(self):
        """Return one loader, or a list of loaders, over the validation split(s)."""
        return self._eval_loaders(self.eval_splits)

    def test_dataloader(self):
        """Return one loader, or a list of loaders, over the test split(s).

        The test splits are looked up by name rather than derived from the
        validation splits, so a task with several test splits is evaluated on
        its test data and not on its validation data.
        """
        test_splits = [x for x in self.dataset.keys() if "test" in x]
        return self._eval_loaders(test_splits)

    def convert_to_features(self, example_batch, indices=None):
        """Tokenize a batch of raw examples and attach their labels.

        Args:
            example_batch: Mapping from column name to a list of values.
            indices: Unused; kept for interface compatibility.

        Returns:
            The tokenizer output with an added ``"labels"`` entry.
        """
        # Either encode single sentence or sentence pairs
        if len(self.text_fields) > 1:
            texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
        else:
            texts_or_text_pairs = example_batch[self.text_fields[0]]

        # Tokenize the text/text pairs. ``padding="max_length"`` replaces the
        # deprecated ``pad_to_max_length=True`` flag.
        features = self.tokenizer.batch_encode_plus(
            texts_or_text_pairs, max_length=self.max_seq_length, padding="max_length", truncation=True
        )

        # Rename label to labels to make it easier to pass to model forward
        features["labels"] = example_batch["label"]

        return features
You could use this datamodule with standalone PyTorch if you wanted…
[4]:
# Drive the datamodule by hand, without a Trainer.
dm = GLUEDataModule("distilbert-base-uncased")
dm.prepare_data()
dm.setup("fit")

# Pull one batch from the training loader to inspect it.
train_loader = dm.train_dataloader()
next(iter(train_loader))
/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:2888: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).
warnings.warn(
[4]:
{'input_ids': tensor([[ 101, 2035, 2216, ..., 0, 0, 0],
[ 101, 1000, 1996, ..., 0, 0, 0],
[ 101, 2021, 24040, ..., 0, 0, 0],
...,
[ 101, 3463, 2013, ..., 0, 0, 0],
[ 101, 2027, 2056, ..., 0, 0, 0],
[ 101, 4584, 2012, ..., 0, 0, 0]]),
'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0],
[1, 1, 1, ..., 0, 0, 0],
[1, 1, 1, ..., 0, 0, 0],
...,
[1, 1, 1, ..., 0, 0, 0],
[1, 1, 1, ..., 0, 0, 0],
[1, 1, 1, ..., 0, 0, 0]]),
'labels': tensor([1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1,
0, 0, 1, 1, 0, 0, 1, 1])}
Transformer LightningModule¶
[5]:
class GLUETransformer(pl.LightningModule):
    """LightningModule that fine-tunes a sequence-classification model on a GLUE task."""

    def __init__(
        self,
        model_name_or_path: str,
        num_labels: int,
        task_name: str,
        learning_rate: float = 2e-5,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        eval_splits: Optional[list] = None,
        **kwargs,
    ):
        """Build the model and the GLUE metric for ``task_name``.

        All arguments are stored in ``self.hparams`` via ``save_hyperparameters``.

        Args:
            model_name_or_path: Identifier passed to ``from_pretrained``.
            num_labels: Number of labels of the classification head.
            task_name: GLUE task name used to select the metric.
            learning_rate: Learning rate of the optimizer.
            adam_epsilon: ``eps`` of the optimizer.
            warmup_steps: Number of warmup steps of the linear schedule.
            weight_decay: Weight decay for parameters that are not excluded
                in ``configure_optimizers``.
            train_batch_size: Stored as a hyperparameter only.
            eval_batch_size: Stored as a hyperparameter only.
            eval_splits: Names of the validation splits, indexed by
                dataloader index; read only for the ``"mnli"`` task.
            **kwargs: Accepted and stored with the other hyperparameters.
        """
        super().__init__()

        self.save_hyperparameters()

        self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)
        # NOTE(review): ``datasets.load_metric`` is deprecated in favour of the
        # separate ``evaluate`` package; it is kept here to avoid adding a
        # new dependency.
        self.metric = datasets.load_metric(
            "glue",
            self.hparams.task_name,
            experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S"),
            trust_remote_code=True,
        )
        # Validation step outputs, grouped by dataloader index.
        self.outputs = defaultdict(list)

    def forward(self, **inputs):
        """Forward all keyword arguments to the wrapped model."""
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        """Run the model on ``batch`` and return the first output as the loss."""
        outputs = self(**batch)
        loss = outputs[0]
        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Run the model on ``batch`` and buffer loss, predictions and labels."""
        outputs = self(**batch)
        val_loss, logits = outputs[:2]

        if self.hparams.num_labels > 1:
            preds = torch.argmax(logits, dim=1)
        else:
            # Drop only the trailing axis: a bare ``squeeze()`` would also
            # remove the batch axis of a size-1 batch and produce a 0-d
            # tensor that ``torch.cat`` cannot concatenate later.
            preds = logits.squeeze(-1)

        labels = batch["labels"]

        self.outputs[dataloader_idx].append({"loss": val_loss, "preds": preds, "labels": labels})

    @staticmethod
    def _collate(step_outputs):
        """Merge buffered step outputs into ``(preds, labels, mean_loss)``."""
        preds = torch.cat([x["preds"] for x in step_outputs]).detach().cpu().numpy()
        labels = torch.cat([x["labels"] for x in step_outputs]).detach().cpu().numpy()
        loss = torch.stack([x["loss"] for x in step_outputs]).mean()
        return preds, labels, loss

    def on_validation_epoch_end(self):
        """Log the validation loss and GLUE metrics, then reset the buffer.

        For ``"mnli"`` every dataloader is logged separately, with the last
        ``_``-separated part of its split name as a suffix. For all other
        tasks the outputs of all dataloaders are pooled.
        """
        try:
            if self.hparams.task_name == "mnli":
                loss = None
                for i, step_outputs in self.outputs.items():
                    # matched or mismatched
                    split = self.hparams.eval_splits[i].split("_")[-1]
                    preds, labels, loss = self._collate(step_outputs)
                    self.log(f"val_loss_{split}", loss, prog_bar=True)
                    split_metrics = {
                        f"{k}_{split}": v
                        for k, v in self.metric.compute(predictions=preds, references=labels).items()
                    }
                    self.log_dict(split_metrics, prog_bar=True)
                return loss

            flat_outputs = [x for lst in self.outputs.values() for x in lst]
            preds, labels, loss = self._collate(flat_outputs)
            self.log("val_loss", loss, prog_bar=True)
            self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)
        finally:
            # Always reset the buffer, including on the early MNLI return;
            # otherwise outputs of earlier validation runs leak into later ones.
            self.outputs.clear()

    def configure_optimizers(self):
        """Prepare optimizer and schedule (linear warmup and decay)."""
        model = self.model
        # Parameters whose name contains one of these substrings get no weight decay.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        # ``transformers.AdamW`` is deprecated; use the PyTorch implementation.
        # Weight decay is set explicitly on every parameter group above, so the
        # optimizer's own default does not apply.
        optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters,
            lr=self.hparams.learning_rate,
            eps=self.hparams.adam_epsilon,
        )

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )
        # Step the scheduler after every optimizer step rather than every epoch.
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}
        return [optimizer], [scheduler]
Training¶
CoLA¶
See an interactive view of the CoLA dataset in NLP Viewer
[6]:
# Fine-tune ALBERT on CoLA for a single epoch.
pl.seed_everything(42)

cola_checkpoint = "albert-base-v2"

dm = GLUEDataModule(model_name_or_path=cola_checkpoint, task_name="cola")
# setup() has to run before the model is built: it populates dm.eval_splits.
dm.setup("fit")

model = GLUETransformer(
    model_name_or_path=cola_checkpoint,
    task_name=dm.task_name,
    num_labels=dm.num_labels,
    eval_splits=dm.eval_splits,
)

trainer = pl.Trainer(max_epochs=1, accelerator="auto", devices=1)
trainer.fit(model, datamodule=dm)
Seed set to 42
Some weights of AlbertForSequenceClassification were not initialized from the model checkpoint at albert-base-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
/tmp/ipykernel_816/3306335502.py:22: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate
self.metric = datasets.load_metric(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: /__w/13/s/lightning_logs
/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:2888: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).
warnings.warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:591: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
warnings.warn(
Loading `train_dataloader` to estimate number of stepping batches.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
| Name | Type | Params | Mode
-----------------------------------------------------------------
0 | model | AlbertForSequenceClassification | 11.7 M | eval
-----------------------------------------------------------------
11.7 M Trainable params
0 Non-trainable params
11.7 M Total params
46.740 Total estimated model params size (MB)
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=1` reached.
MRPC¶
See an interactive view of the MRPC dataset in NLP Viewer
[7]:
# Fine-tune DistilBERT on MRPC for three epochs.
pl.seed_everything(42)

mrpc_checkpoint = "distilbert-base-cased"

dm = GLUEDataModule(model_name_or_path=mrpc_checkpoint, task_name="mrpc")
# setup() has to run before the model is built: it populates dm.eval_splits.
dm.setup("fit")

model = GLUETransformer(
    model_name_or_path=mrpc_checkpoint,
    task_name=dm.task_name,
    num_labels=dm.num_labels,
    eval_splits=dm.eval_splits,
)

trainer = pl.Trainer(max_epochs=3, accelerator="auto", devices=1)
trainer.fit(model, datamodule=dm)
Seed set to 42
/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:2888: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).
warnings.warn(
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:2888: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).
warnings.warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:591: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
warnings.warn(
Loading `train_dataloader` to estimate number of stepping batches.
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
| Name | Type | Params | Mode
---------------------------------------------------------------------
0 | model | DistilBertForSequenceClassification | 65.8 M | eval
---------------------------------------------------------------------
65.8 M Trainable params
0 Non-trainable params
65.8 M Total params
263.132 Total estimated model params size (MB)
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
`Trainer.fit` stopped: `max_epochs=3` reached.
MNLI¶
The MNLI dataset is huge, so we aren’t going to bother trying to train on it here.
We will skip over training and go straight to validation.
See an interactive view of the MNLI dataset in NLP Viewer
[8]:
# Run validation only on MNLI; no training is performed in this cell.
mnli_checkpoint = "distilbert-base-cased"

dm = GLUEDataModule(model_name_or_path=mnli_checkpoint, task_name="mnli")
# setup() has to run before the model is built: it populates dm.eval_splits.
dm.setup("fit")

model = GLUETransformer(
    model_name_or_path=mnli_checkpoint,
    task_name=dm.task_name,
    num_labels=dm.num_labels,
    eval_splits=dm.eval_splits,
)

trainer = pl.Trainer(max_epochs=3, accelerator="auto", devices=1)
trainer.validate(model, dm)
/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:2888: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).
warnings.warn(
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:2888: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).
warnings.warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Validate metric DataLoader 0 DataLoader 1
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
accuracy_matched 0.31767702102661133 0.31767702102661133
accuracy_mismatched 0.3187550902366638 0.3187550902366638
val_loss_matched 1.1031880378723145 1.1031880378723145
val_loss_mismatched 1.1027535200119019 1.1027535200119019
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
[8]:
[{'val_loss_matched': 1.1031880378723145,
'accuracy_matched': 0.31767702102661133,
'val_loss_mismatched': 1.1027535200119019,
'accuracy_mismatched': 0.3187550902366638},
{'val_loss_matched': 1.1031880378723145,
'accuracy_matched': 0.31767702102661133,
'val_loss_mismatched': 1.1027535200119019,
'accuracy_mismatched': 0.3187550902366638}]
Congratulations - Time to Join the Community!¶
Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!
Star Lightning on GitHub¶
The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.
Join our Discord!¶
The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.
Contributions !¶
The best way to contribute to our community is to become a code contributor! At any time you can go to the Lightning or Bolt GitHub Issues pages and filter for “good first issue”.
You can also contribute your own notebooks with useful examples!