[ ]:
# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.
Custom Trainers
Written by: Caleb Robinson
In this tutorial, we demonstrate how to extend a TorchGeo “trainer class”. TorchGeo provides several trainer classes: pre-made PyTorch Lightning modules that make it easy to train models for tasks such as semantic segmentation, classification, and change detection using TorchGeo’s prebuilt DataModules. While the trainers aim to provide sensible defaults and customization options for common tasks, they cannot cover every situation (e.g., researchers will likely want to use their own architectures, loss functions, optimizers, and/or metrics in the training routine). In such a situation, you can extend the trainer class you are interested in and override its default functionality with custom logic.
This tutorial shows how to do exactly this to customize the learning rate schedule, metrics, logging, and model checkpointing for a semantic segmentation task using the LandCover.ai dataset.
It’s recommended to run this notebook on Google Colab if you don’t have your own GPU. Click the “Open in Colab” button above to get started.
Setup
As always, we install TorchGeo.
[ ]:
%pip install torchgeo
Imports
Next, we import TorchGeo and any other libraries we need.
[ ]:
import os
import tempfile
from collections.abc import Sequence
from typing import Any
import lightning
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.callback import Callback
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchmetrics import MetricCollection
from torchmetrics.classification import (
    Accuracy,
    FBetaScore,
    JaccardIndex,
    Precision,
    Recall,
)
from torchgeo.datamodules import LandCoverAI100DataModule
from torchgeo.trainers import SemanticSegmentationTask
Custom SemanticSegmentationTask
Now, we create a CustomSemanticSegmentationTask class that inherits from SemanticSegmentationTask and overrides a few methods:
- __init__: We add two new parameters, tmax and eta_min, to control the learning rate scheduler
- configure_optimizers: We use the CosineAnnealingLR learning rate scheduler instead of the default ReduceLROnPlateau
- configure_metrics: We add a “MeanIoU” metric (which we will use to evaluate the model’s performance) and a variety of other classification metrics
- configure_callbacks: We demonstrate how to stack ModelCheckpoint callbacks to save the best checkpoint as well as periodic checkpoints
- on_train_epoch_start: We log the learning rate at the start of each epoch so we can easily see how it decays over a training run
Overall, these overrides demonstrate how to customize the training routine to investigate specific research questions (e.g., the effect of the scheduler on test performance).
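To make the roles of tmax and eta_min concrete, the closed-form cosine annealing schedule can be sketched in plain Python. This is only an illustration, not TorchGeo or PyTorch code: the helper name cosine_annealing_lr is hypothetical, and the sketch assumes the scheduler is stepped once per epoch and that training does not run past tmax epochs.

```python
import math


def cosine_annealing_lr(epoch: int, base_lr: float, tmax: int, eta_min: float) -> float:
    """Closed-form cosine annealing from base_lr (epoch 0) to eta_min (epoch tmax)."""
    cosine = math.cos(math.pi * epoch / tmax)
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + cosine)


# Values matching this tutorial's settings: lr=1e-3, tmax=50, eta_min=1e-6
schedule = [
    cosine_annealing_lr(epoch, base_lr=1e-3, tmax=50, eta_min=1e-6)
    for epoch in range(51)
]

# Print the start, midpoint, and end of the schedule
for epoch in (0, 25, 50):
    print(f'epoch {epoch:2d}: lr = {schedule[epoch]:.6e}')
```

The learning rate starts at lr, passes through the average of lr and eta_min halfway through, and reaches eta_min after tmax epochs, which is the decay that the on_train_epoch_start hook logs.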
[ ]:
class CustomSemanticSegmentationTask(SemanticSegmentationTask):
    # any keywords we add here between *args and **kwargs will be found in self.hparams
    def __init__(
        self, *args: Any, tmax: int = 50, eta_min: float = 1e-6, **kwargs: Any
    ) -> None:
        super().__init__(*args, **kwargs)  # pass args and kwargs to the parent class

    def configure_optimizers(
        self,
    ) -> 'lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig':
        """Initialize the optimizer and learning rate scheduler.

        Returns:
            Optimizer and learning rate scheduler.
        """
        tmax: int = self.hparams['tmax']
        eta_min: float = self.hparams['eta_min']

        optimizer = AdamW(self.parameters(), lr=self.hparams['lr'])
        scheduler = CosineAnnealingLR(optimizer, T_max=tmax, eta_min=eta_min)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {'scheduler': scheduler, 'monitor': self.monitor},
        }

    def configure_metrics(self) -> None:
        """Initialize the performance metrics."""
        num_classes: int = self.hparams['num_classes']

        self.train_metrics = MetricCollection(
            {
                'OverallAccuracy': Accuracy(
                    task='multiclass', num_classes=num_classes, average='micro'
                ),
                'OverallPrecision': Precision(
                    task='multiclass', num_classes=num_classes, average='micro'
                ),
                'OverallRecall': Recall(
                    task='multiclass', num_classes=num_classes, average='micro'
                ),
                'OverallF1Score': FBetaScore(
                    task='multiclass',
                    num_classes=num_classes,
                    beta=1.0,
                    average='micro',
                ),
                'MeanIoU': JaccardIndex(
                    num_classes=num_classes, task='multiclass', average='macro'
                ),
            },
            prefix='train_',
        )
        self.val_metrics = self.train_metrics.clone(prefix='val_')
        self.test_metrics = self.train_metrics.clone(prefix='test_')

    def configure_callbacks(self) -> Sequence[Callback] | Callback:
        """Initialize callbacks for saving the best and latest models.

        Returns:
            List of callbacks to apply.
        """
        return [
            ModelCheckpoint(every_n_epochs=50, save_top_k=-1, save_last=True),
            ModelCheckpoint(monitor=self.monitor, mode=self.mode, save_top_k=5),
        ]

    def on_train_epoch_start(self) -> None:
        """Log the learning rate at the start of each training epoch."""
        optimizers = self.optimizers()
        if isinstance(optimizers, list):
            lr = optimizers[0].param_groups[0]['lr']
        else:
            lr = optimizers.param_groups[0]['lr']
        self.logger.experiment.add_scalar('lr', lr, self.current_epoch)
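The “Overall” metrics configured above are micro-averaged (pooled over all pixels), while MeanIoU is macro-averaged (computed per class, then averaged). The plain-Python sketch below illustrates why the two can disagree on imbalanced data. It does not use torchmetrics: the helper names are hypothetical, the labels are made up for illustration, and it assumes every class appears at least once in the targets or predictions.

```python
def confusion_matrix(
    targets: list[int], preds: list[int], num_classes: int
) -> list[list[int]]:
    """Count (target, prediction) pairs; rows are targets, columns are predictions."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(targets, preds):
        cm[t][p] += 1
    return cm


def overall_accuracy(cm: list[list[int]]) -> float:
    """Micro-averaged accuracy: correctly labeled pixels over all pixels."""
    correct = sum(cm[c][c] for c in range(len(cm)))
    total = sum(sum(row) for row in cm)
    return correct / total


def mean_iou(cm: list[list[int]]) -> float:
    """Macro-averaged IoU: per-class TP / (TP + FP + FN), then an unweighted mean.

    Assumes every class occurs in the targets or predictions (no zero denominators).
    """
    n = len(cm)
    ious = []
    for c in range(n):
        tp = cm[c][c]
        fp = sum(cm[r][c] for r in range(n)) - tp
        fn = sum(cm[c]) - tp
        ious.append(tp / (tp + fp + fn))
    return sum(ious) / n


# Made-up, imbalanced toy labels: class 0 dominates and class 1 is always missed
targets = [0] * 8 + [1] + [2]
preds = [0] * 8 + [0] + [2]

cm = confusion_matrix(targets, preds, num_classes=3)
acc = overall_accuracy(cm)
miou = mean_iou(cm)
print(f'overall accuracy = {acc:.3f}, mean IoU = {miou:.3f}')
```

In this toy case the overall accuracy stays high because the dominant class is predicted correctly, whereas the mean IoU drops sharply because the rare class is never recovered. This is one reason to track MeanIoU alongside the overall metrics for land cover segmentation.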
Train model
The remainder of the tutorial is straightforward and follows the typical PyTorch Lightning training routine. We instantiate a DataModule for the LandCover.ai 100 dataset (a small version of the LandCover.ai dataset for notebook testing), instantiate a CustomSemanticSegmentationTask with a U-Net model and a ResNet-18 backbone, then train the model using a Lightning trainer.
The following variables can be modified to control training.
[ ]:
batch_size = 32
num_workers = 8
max_epochs = 50
fast_dev_run = False
use_pretrained = True
[ ]:
root = os.path.join(tempfile.gettempdir(), 'segmentation')
dm = LandCoverAI100DataModule(
    root=root, batch_size=batch_size, num_workers=num_workers, download=True
)
[ ]:
task = CustomSemanticSegmentationTask(
    model='unet',
    backbone='resnet18',
    weights=use_pretrained,
    in_channels=3,
    num_classes=6,
    loss='ce',
    lr=1e-3,
    tmax=50,
)
[ ]:
# validate that the task's hyperparameters are as expected
task.hparams
[ ]:
trainer = pl.Trainer(
    fast_dev_run=fast_dev_run, log_every_n_steps=1, min_epochs=1, max_epochs=max_epochs
)
[ ]:
trainer.fit(task, dm)
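With max_epochs = 50 and every_n_epochs=50 in the first ModelCheckpoint callback, it can help to check when the periodic checkpoint would actually be written. The sketch below is plain Python with a hypothetical helper, not a Lightning API. It assumes, based on our reading of ModelCheckpoint, that a periodic save happens at the end of every epoch for which (current_epoch + 1) is divisible by every_n_epochs; please verify this against the Lightning version you have installed.

```python
def periodic_checkpoint_epochs(max_epochs: int, every_n_epochs: int) -> list[int]:
    """Zero-based epoch indices at which a periodic checkpoint would be saved.

    Assumption (not taken from this tutorial): a save is triggered when
    (epoch + 1) % every_n_epochs == 0.
    """
    return [
        epoch for epoch in range(max_epochs) if (epoch + 1) % every_n_epochs == 0
    ]


# Settings used in this tutorial
default_epochs = periodic_checkpoint_epochs(max_epochs=50, every_n_epochs=50)

# A more frequent, purely illustrative alternative
frequent_epochs = periodic_checkpoint_epochs(max_epochs=50, every_n_epochs=10)

print(default_epochs)
print(frequent_epochs)
```

Under that assumption, the tutorial's settings yield a single periodic checkpoint at the end of training, while a smaller every_n_epochs would also keep intermediate snapshots.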
Test model
Finally, we test the model (optionally loading from a previously saved checkpoint).
[ ]:
# You can load directly from a saved checkpoint with `.load_from_checkpoint(...)`
# Note that you can also just call `trainer.test(task, dm)` if you've already trained
# the model in the current notebook session.
# task = CustomSemanticSegmentationTask.load_from_checkpoint(
#     os.path.join('lightning_logs', 'version_0', 'checkpoints', 'epoch=0-step=1.ckpt')
# )
trainer.test(task, dm)