# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.
"""Trainers for semantic segmentation."""
import os
from collections.abc import Sequence
from typing import Any, Literal
import kornia.augmentation as K
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
from einops import rearrange
from matplotlib.figure import Figure
from torch import Tensor
from torchvision.models._api import WeightsEnum
from ..datamodules import BaseDataModule
from ..datasets import RGBBandsMissingError, unbind_samples
from ..models import FCN, get_weight
from . import utils
from .base import BaseTask
from .mixins import ClassificationMixin
[docs]
class SemanticSegmentationTask(ClassificationMixin, BaseTask):
"""Semantic Segmentation."""
[docs]
def __init__(
self,
model: Literal[
'unet', 'deeplabv3+', 'fcn', 'upernet', 'segformer', 'dpt'
] = 'unet',
backbone: str = 'resnet50',
weights: WeightsEnum | str | bool | None = None,
in_channels: int = 3,
task: Literal['binary', 'multiclass', 'multilabel'] = 'multiclass',
num_classes: int | None = None,
num_labels: int | None = None,
labels: list[str] | None = None,
num_filters: int = 3,
pos_weight: Tensor | None = None,
loss: Literal['ce', 'bce', 'jaccard', 'focal', 'dice'] = 'ce',
class_weights: Tensor | Sequence[float] | None = None,
ignore_index: int | None = None,
lr: float = 1e-3,
patience: int = 10,
freeze_backbone: bool = False,
freeze_decoder: bool = False,
) -> None:
"""Initialize a new SemanticSegmentationTask instance.
Args:
model: Name of the
`smp <https://smp.readthedocs.io/en/latest/models.html>`__ model to use.
backbone: Name of the `timm
<https://smp.readthedocs.io/en/latest/encoders_timm.html>`__ or `smp
<https://smp.readthedocs.io/en/latest/encoders.html>`__ backbone to use.
weights: Initial model weights. Either a weight enum, the string
representation of a weight enum, True for ImageNet weights, False or
None for random weights, or the path to a saved model state dict. FCN
model does not support pretrained weights.
in_channels: Number of input channels to model.
task: One of 'binary', 'multiclass', or 'multilabel'.
num_classes: Number of prediction classes (only for ``task='multiclass'``).
num_labels: Number of prediction labels (only for ``task='multilabel'``).
labels: List of class names.
num_filters: Number of filters. Only applicable when model='fcn'.
pos_weight: A weight of positive examples and used with 'bce' loss.
loss: Name of the loss function, currently supports
'ce', 'bce', 'jaccard', 'focal', and 'dice' loss.
class_weights: Optional rescaling weight given to each
class and used with 'ce' loss.
ignore_index: Optional integer class index to ignore in the loss and
metrics.
lr: Learning rate for optimizer.
patience: Patience for learning rate scheduler.
freeze_backbone: Freeze the backbone network to fine-tune the
decoder and segmentation head.
freeze_decoder: Freeze the decoder network to linear probe
the segmentation head.
.. versionadded:: 0.9
The *labels* and *pos_weight* parameters and dice loss support.
.. versionadded:: 0.8
Time series, DPT, Segformer, and UPerNet support.
.. versionadded:: 0.7
The *task* and *num_labels* parameters.
.. versionchanged:: 0.6
The *ignore_index* parameter now works for jaccard loss.
.. versionadded:: 0.5
The *class_weights*, *freeze_backbone*, and *freeze_decoder* parameters.
.. versionchanged:: 0.5
The *weights* parameter now supports WeightEnums and checkpoint paths.
*learning_rate* and *learning_rate_schedule_patience* were renamed to
*lr* and *patience*.
.. versionchanged:: 0.4
*segmentation_model*, *encoder_name*, and *encoder_weights*
were renamed to *model*, *backbone*, and *weights*.
.. versionchanged:: 0.3
*ignore_zeros* was renamed to *ignore_index*.
"""
self.weights = weights
super().__init__()
[docs]
def forward(self, x: Tensor) -> Tensor:
"""Forward pass of the model.
Args:
x: Input tensor of shape (B, C, H, W) or (B, T, C, H, W).
Returns:
Output tensor of shape (B, num_classes, H, W).
"""
if x.ndim == 5:
x = rearrange(x, 'b t c h w -> b (t c) h w')
x = self.model(x)
return x
[docs]
def training_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
) -> Tensor:
"""Compute the training loss and additional metrics.
Args:
batch: The output of your DataLoader.
batch_idx: Integer displaying index of this batch.
dataloader_idx: Index of the current dataloader.
Returns:
The loss tensor.
"""
x = batch['image']
y = batch['mask']
batch_size = x.shape[0]
y_hat = self(x).squeeze(1)
self.train_metrics(y_hat, y)
if self.hparams['loss'] == 'bce':
y = y.float()
loss: Tensor = self.criterion(y_hat, y)
self.log('train_loss', loss, batch_size=batch_size)
return loss
[docs]
def validation_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
) -> None:
"""Compute the validation loss and additional metrics.
Args:
batch: The output of your DataLoader.
batch_idx: Integer displaying index of this batch.
dataloader_idx: Index of the current dataloader.
"""
x = batch['image']
y = batch['mask']
batch_size = x.shape[0]
y_hat = self(x).squeeze(1)
self.val_metrics(y_hat, y)
if self.hparams['loss'] == 'bce':
y = y.float()
loss = self.criterion(y_hat, y)
self.log('val_loss', loss, batch_size=batch_size)
if (
batch_idx < 10
and hasattr(self.trainer, 'datamodule')
and isinstance(self.trainer.datamodule, BaseDataModule)
and self.logger
and hasattr(self.logger, 'experiment')
and hasattr(self.logger.experiment, 'add_figure')
):
datamodule = self.trainer.datamodule
if batch['image'].ndim == 5:
_, T, C, _, _ = batch['image'].shape
batch['image'] = rearrange(batch['image'], 'b t c h w -> b (t c) h w')
aug = K.AugmentationSequential(
K.Denormalize(datamodule.mean, datamodule.std),
data_keys=None,
keepdim=True,
)
batch = aug(batch)
batch['image'] = rearrange(
batch['image'], 'b (t c) h w -> b t c h w', t=T, c=C
)
else:
aug = K.AugmentationSequential(
K.Denormalize(datamodule.mean, datamodule.std),
data_keys=None,
keepdim=True,
)
batch = aug(batch)
match self.hparams['task']:
case 'binary' | 'multilabel':
batch['prediction'] = (y_hat.sigmoid() >= 0.5).long()
case 'multiclass':
batch['prediction'] = y_hat.argmax(dim=1)
for key in ['image', 'mask', 'prediction']:
batch[key] = batch[key].cpu()
sample = unbind_samples(batch)[0]
fig: Figure | None = None
try:
fig = datamodule.plot(sample)
except RGBBandsMissingError:
pass
if fig:
summary_writer = self.logger.experiment
summary_writer.add_figure(
f'image/{batch_idx}', fig, global_step=self.global_step
) # type: ignore[call-non-callable]
plt.close()
[docs]
def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
"""Compute the test loss and additional metrics.
Args:
batch: The output of your DataLoader.
batch_idx: Integer displaying index of this batch.
dataloader_idx: Index of the current dataloader.
"""
x = batch['image']
y = batch['mask']
batch_size = x.shape[0]
y_hat = self(x).squeeze(1)
self.test_metrics(y_hat, y)
if self.hparams['loss'] == 'bce':
y = y.float()
loss = self.criterion(y_hat, y)
self.log('test_loss', loss, batch_size=batch_size)
[docs]
def predict_step(
self, batch: Any, batch_idx: int, dataloader_idx: int = 0
) -> dict[str, Tensor]:
"""Compute the predicted class probabilities.
Args:
batch: The output of your DataLoader.
batch_idx: Integer displaying index of this batch.
dataloader_idx: Index of the current dataloader.
Returns:
Dictionary with 'probabilities', 'bounds', and 'transform' keys.
.. versionchanged:: 0.9
Changed return type from Tensor to dict with probabilities, bounds,
and transform keys.
"""
x = batch['image']
y_hat: Tensor = self(x)
match self.hparams['task']:
case 'binary' | 'multilabel':
y_hat = y_hat.sigmoid()
case 'multiclass':
y_hat = y_hat.softmax(dim=1)
return {
'probabilities': y_hat,
'bounds': batch.get('bounds'),
'transform': batch.get('transform'),
}