TPU training with PyTorch Lightning

  • Author: Lightning.ai

  • License: CC BY-SA

  • Generated: 2024-07-23T19:27:26.548935

In this notebook, we’ll train a model on TPUs. Updating one Trainer flag is all you need for that. The most up-to-date documentation on TPU training can be found here.


Setup

This notebook requires some packages besides pytorch-lightning.

[ ]:
! pip install --quiet "pytorch-lightning>=2.0, <2.1.0" "matplotlib" "numpy <2.0" "torchvision" "torchmetrics>=1.0, <1.3" "torch>=1.8.1, <2.1.0"

Install the Colab TPU-compatible PyTorch/TPU wheels and dependencies.

[ ]:
! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl
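
Before moving on, it can help to confirm that the packages installed above are actually visible to the notebook kernel. The following is a minimal sketch that uses only the standard library; `is_installed` is a hypothetical helper name introduced here for illustration, not part of Lightning or torch_xla.

```python
import importlib.util


def is_installed(package: str) -> bool:
    """Return True if a top-level `package` can be located, without importing it."""
    return importlib.util.find_spec(package) is not None


# Report on the packages this notebook relies on.
for name in ("torch", "torch_xla", "pytorch_lightning"):
    status = "found" if is_installed(name) else "MISSING"
    print(f"{name}: {status}")
```

If `torch_xla` is reported as missing, rerun the install cell (and restart the runtime if needed) before creating a Trainer with accelerator="tpu".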

[ ]:
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchmetrics.functional import accuracy
from torchvision import transforms

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST

BATCH_SIZE = 1024

Defining The MNISTDataModule

Below we define MNISTDataModule. You can learn more about datamodules in the docs.

[ ]:
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)
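
With BATCH_SIZE = 1024, each split produces only a handful of batches per epoch. The sketch below works this out from the split sizes used in setup() (55,000 train, 5,000 validation) and the standard 10,000-image MNIST test set. It assumes the DataLoader default of drop_last=False, so a final partial batch is kept; `batches_per_epoch` is a hypothetical helper name.

```python
import math

BATCH_SIZE = 1024  # same value as the notebook's constant

# Split sizes from MNISTDataModule.setup() plus the standard MNIST test set size.
split_sizes = {"train": 55_000, "val": 5_000, "test": 10_000}


def batches_per_epoch(num_samples: int, batch_size: int = BATCH_SIZE) -> int:
    """Number of batches a DataLoader yields when drop_last=False (the default)."""
    return math.ceil(num_samples / batch_size)


for split, n in split_sizes.items():
    print(f"{split}: {n} samples -> {batches_per_epoch(n)} batches")
```

On a single device this gives 54 training batches, 5 validation batches, and 10 test batches per epoch.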

Defining the LitModel

Below, we define the model LitModel.

[ ]:
class LitModel(pl.LightningModule):
    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):
        super().__init__()

        self.save_hyperparameters()

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
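        # torchmetrics>=1.0 requires the task to be stated explicitly
        acc = accuracy(preds, y, task="multiclass", num_classes=self.hparams.num_classes)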
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer
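
LitModel.forward returns log-probabilities (via F.log_softmax), which is why the training and validation steps use F.nll_loss rather than F.cross_entropy. The two formulations give the same loss. The sketch below checks this on random CPU tensors; the batch size of 4 and the target values are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Fake raw scores for a batch of 4 samples over 10 classes, plus integer labels.
raw_scores = torch.randn(4, 10)
targets = torch.tensor([3, 1, 7, 0])

# What LitModel does: log_softmax in forward(), nll_loss in the step.
loss_nll = F.nll_loss(F.log_softmax(raw_scores, dim=1), targets)

# The fused alternative, applied directly to the raw scores.
loss_ce = F.cross_entropy(raw_scores, targets)

print(torch.allclose(loss_nll, loss_ce))
```

One consequence of this design is that the variable named logits in the step methods actually holds log-probabilities. Taking argmax over them still yields the predicted class, because log_softmax preserves the ordering of the scores.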

TPU Training

Lightning supports training on a single TPU core or 8 TPU cores.

The Trainer argument devices sets either the number of TPU cores to train on (1 or 8) or, when given as a one-element list such as [1], the specific TPU core to train on. Use it together with accelerator='tpu'.

For single-core TPU training, pass the TPU core ID (1-8) in a list. For example, devices=[5] trains on TPU core ID 5.
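
The accepted forms of devices can be summarized in a few lines of plain Python. This is only a restatement of the rules described in this section, not how Lightning parses the argument internally; `describe_tpu_devices` is a hypothetical helper introduced for illustration.

```python
from typing import List, Union


def describe_tpu_devices(devices: Union[int, List[int]]) -> str:
    """Hypothetical helper: explain a `devices` value per the rules in this section."""
    if isinstance(devices, list):
        # A list selects one specific core by its ID.
        if len(devices) != 1:
            raise ValueError("pass exactly one TPU core ID in the list")
        return f"single TPU core, ID {devices[0]}"
    if devices in (1, 8):
        # An int is a core count; only 1 or 8 are described as supported.
        return f"{devices} TPU core(s)"
    raise ValueError("devices must be 1, 8, or a one-element list with a core ID")


print(describe_tpu_devices([5]))
print(describe_tpu_devices(1))
print(describe_tpu_devices(8))
```

The three calls correspond to the three training cells that follow: a specific core, one core, and all eight cores.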

Train on TPU core ID 5 with devices=[5].

[ ]:
# Init DataModule
dm = MNISTDataModule()
# Init model from datamodule's attributes
model = LitModel(*dm.dims, dm.num_classes)
# Init trainer
trainer = pl.Trainer(
    max_epochs=3,
    accelerator="tpu",
    devices=[5],
)
# Train
trainer.fit(model, dm)

Train on single TPU core with devices=1.

[ ]:
# Init DataModule
dm = MNISTDataModule()
# Init model from datamodule's attributes
model = LitModel(*dm.dims, dm.num_classes)
# Init trainer
trainer = pl.Trainer(
    max_epochs=3,
    accelerator="tpu",
    devices=1,
)
# Train
trainer.fit(model, dm)

Train on 8 TPU cores with accelerator='tpu' and devices=8. You might have to restart the notebook to run it on 8 TPU cores after training on a single TPU core.

[ ]:
# Init DataModule
dm = MNISTDataModule()
# Init model from datamodule's attributes
model = LitModel(*dm.dims, dm.num_classes)
# Init trainer
trainer = pl.Trainer(
    max_epochs=3,
    accelerator="tpu",
    devices=8,
)
# Train
trainer.fit(model, dm)
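
Moving from 1 to 8 cores changes the effective batch size, because every core processes its own batch of BATCH_SIZE samples in each step. The arithmetic below is a sketch under the assumption that the training split is sharded evenly across cores, as a distributed sampler would do; the exact behavior depends on the Lightning and torch_xla versions in use.

```python
import math

BATCH_SIZE = 1024        # per-core batch size, as in the notebook
NUM_CORES = 8            # devices=8
TRAIN_SAMPLES = 55_000   # training split size from MNISTDataModule

# Assumption: each core sees an equal shard of the training split
# and draws batches of BATCH_SIZE from that shard.
effective_batch_size = BATCH_SIZE * NUM_CORES
samples_per_core = math.ceil(TRAIN_SAMPLES / NUM_CORES)
steps_per_epoch = math.ceil(samples_per_core / BATCH_SIZE)

print(f"effective batch size: {effective_batch_size}")
print(f"samples per core:     {samples_per_core}")
print(f"steps per epoch:      {steps_per_epoch}")
```

Under these assumptions an epoch takes 7 optimizer steps with an effective batch of 8,192 samples, compared with 54 steps of 1,024 samples on a single core. With so few steps per epoch, the learning rate or max_epochs may need adjusting to reach comparable accuracy.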

Congratulations - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the Lightning movement, you can do so in the following ways!

Star Lightning on GitHub

The easiest way to help our community is just by starring the GitHub repos! This helps raise awareness of the cool tools we’re building.

Join our Discord!

The best way to keep up to date on the latest advancements is to join our community! Make sure to introduce yourself and share your interests in the #general channel.

Contributions!

The best way to contribute to our community is to become a code contributor! At any time you can go to the Lightning or Bolt GitHub Issues page and filter for “good first issue”.

Many thanks from the entire PyTorch Lightning team for your interest!
