Source code for torchgeo.trainers.mixins
# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.
"""Mix-ins for trainers."""
import segmentation_models_pytorch as smp
import torch
from lightning.pytorch import LightningModule
from torch import Tensor, nn
from torchmetrics import Metric, MetricCollection
from torchmetrics.classification import (
Accuracy,
F1Score,
JaccardIndex,
Precision,
Recall,
)
from torchmetrics.wrappers import ClasswiseWrapper
[docs]
class ClassificationMixin(LightningModule):
"""Mix-in for classification-based tasks.
.. versionadded:: 0.9
"""
[docs]
def configure_losses(self) -> None:
"""Initialize the loss criterion."""
ignore_index: int | None = self.hparams['ignore_index']
class_weights = self.hparams['class_weights']
if class_weights is not None and not isinstance(class_weights, Tensor):
class_weights = torch.tensor(class_weights, dtype=torch.float32)
match self.hparams['loss']:
case 'ce':
ignore_value = -1000 if ignore_index is None else ignore_index
self.criterion: nn.Module = nn.CrossEntropyLoss(
ignore_index=ignore_value, weight=class_weights
)
case 'bce':
self.criterion = nn.BCEWithLogitsLoss(
pos_weight=self.hparams['pos_weight']
)
case 'jaccard':
# JaccardLoss requires a list of classes to use instead of a class
# index to ignore.
if self.hparams['task'] == 'multiclass' and ignore_index is not None:
classes = [
i
for i in range(self.hparams['num_classes'])
if i != ignore_index
]
self.criterion = smp.losses.JaccardLoss(
mode=self.hparams['task'], classes=classes
)
else:
self.criterion = smp.losses.JaccardLoss(mode=self.hparams['task'])
case 'focal':
self.criterion = smp.losses.FocalLoss(
mode=self.hparams['task'],
ignore_index=ignore_index,
normalized=True,
)
case 'dice':
self.criterion = smp.losses.DiceLoss(
mode=self.hparams['task'], ignore_index=ignore_index
)
[docs]
def configure_metrics(self) -> None:
r"""Initialize the performance metrics.
Includes the following metrics:
* :class:`~torchmetrics.Accuracy`: :math:`\frac{TP + TN}{P + N}`
* :class:`~torchmetrics.Precision`: :math:`\frac{TP}{TP + FP}`
* :class:`~torchmetrics.Recall`: :math:`\frac{TP}{P}`
* :class:`~torchmetrics.F1Score`: :math:`\frac{2 TP}{2 TP + FP + FN}`
* :class:`~torchmetrics.JaccardIndex`: :math:`\frac{TP}{TP + FN + FP}`
See https://en.wikipedia.org/wiki/Evaluation_of_binary_classifiers
for more details.
Higher values are better for all metrics. All metrics report multiple versions:
* Overall (micro): Sum statistics over all labels
* Average (macro): Calculate statistics for each label and average them
* Classwise (none): Calculates statistic for each label and applies no reduction
"""
kwargs = {
'task': self.hparams['task'],
'num_classes': self.hparams['num_classes'],
'num_labels': self.hparams['num_labels'],
'ignore_index': self.hparams['ignore_index'],
}
metrics_dict: dict[str, Metric | MetricCollection] = {}
for metric in [Accuracy, Precision, Recall, F1Score, JaccardIndex]:
metrics_dict |= {
f'Overall{metric.__name__}': metric(average='micro', **kwargs),
f'Average{metric.__name__}': metric(average='macro', **kwargs),
}
if self.hparams['task'] != 'binary':
metrics_dict[metric.__name__] = ClasswiseWrapper(
metric(average='none', **kwargs),
labels=self.hparams['labels'],
prefix=f'Classwise{metric.__name__}_',
)
metrics = MetricCollection(metrics_dict)
self.train_metrics = metrics.clone(prefix='train_')
self.val_metrics = metrics.clone(prefix='val_')
self.test_metrics = metrics.clone(prefix='test_')
[docs]
def on_train_epoch_end(self) -> None:
"""Log train metrics."""
self.log_dict(self.train_metrics.compute())
self.train_metrics.reset()
[docs]
def on_validation_epoch_end(self) -> None:
"""Log validation metrics."""
self.log_dict(self.val_metrics.compute())
self.val_metrics.reset()
[docs]
def on_test_epoch_end(self) -> None:
"""Log test metrics."""
self.log_dict(self.test_metrics.compute())
self.test_metrics.reset()