torchgeo.trainers#

TorchGeo trainers.

class torchgeo.trainers.BYOLTask(model='resnet50', weights=None, in_channels=3, lr=0.001, patience=10)[source]#

Bases: BaseTask

BYOL: Bootstrap Your Own Latent.

Reference implementation:

If you use this trainer in your research, please cite the following paper:

monitor = 'train_loss'#

Performance metric to monitor in learning rate scheduler and callbacks.

__init__(model='resnet50', weights=None, in_channels=3, lr=0.001, patience=10)[source]#

Initialize a new BYOLTask instance.

Parameters:
  • model (str) – Name of the timm model to use.

  • weights (WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.

  • in_channels (int) – Number of input channels to model.

  • lr (float) – Learning rate for optimizer.

  • patience (int) – Patience for learning rate scheduler.

Changed in version 0.4: backbone_name was renamed to backbone. Changed backbone support from torchvision.models to timm.

Changed in version 0.5: backbone, learning_rate, and learning_rate_schedule_patience were renamed to model, lr, and patience.

configure_models()[source]#

Initialize the model.

training_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the training loss and additional metrics.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Raises:

AssertionError – If channel dimensions are incorrect.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]#

No-op, does nothing.

test_step(batch, batch_idx, dataloader_idx=0)[source]#

No-op, does nothing.

predict_step(batch, batch_idx, dataloader_idx=0)[source]#

No-op, does nothing.

class torchgeo.trainers.BaseTask[source]#

Bases: LightningModule, ABC

Abstract base class for all TorchGeo trainers.

Added in version 0.5.

ignore: Sequence[str] | str | None = 'weights'#

Parameters to ignore when saving hyperparameters.

model: Any#

Model to train.

monitor = 'val_loss'#

Performance metric to monitor in learning rate scheduler and callbacks.

mode = 'min'#

Whether the goal is to minimize or maximize the performance metric to monitor.

__init__()[source]#

Initialize a new BaseTask instance.

Parameters:

ignore – Arguments to skip when saving hyperparameters.

abstractmethod configure_models()[source]#

Initialize the model.

configure_losses()[source]#

Initialize the loss criterion.

configure_metrics()[source]#

Initialize the performance metrics.

configure_optimizers()[source]#

Initialize the optimizer and learning rate scheduler.

Returns:

Optimizer and learning rate scheduler.

Return type:

Optimizer | Sequence[Optimizer] | tuple[Sequence[Optimizer], Sequence[LRScheduler | ReduceLROnPlateau | LRSchedulerConfig]] | OptimizerConfig | OptimizerLRSchedulerConfig | Sequence[OptimizerConfig] | Sequence[OptimizerLRSchedulerConfig] | None
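The monitor, mode, and patience settings interact through the learning rate scheduler. The following is a minimal, standard-library sketch of plateau-style scheduling, assuming a ReduceLROnPlateau-like rule (the return type above mentions ReduceLROnPlateau, but the exact scheduler configuration is an assumption here). The function name plateau_schedule and the factor argument are hypothetical, and thresholds and cooldown are deliberately ignored.

```python
def plateau_schedule(values, lr=0.001, patience=10, factor=0.1, mode="min"):
    """Return the learning rate used after each epoch of monitored values.

    Simplified illustration: the rate is multiplied by ``factor`` once the
    monitored value has failed to improve for more than ``patience`` epochs.
    """
    best = float("inf") if mode == "min" else float("-inf")
    bad_epochs = 0
    history = []
    for value in values:
        improved = value < best if mode == "min" else value > best
        if improved:
            best = value
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            lr *= factor      # reduce the learning rate
            bad_epochs = 0    # start counting again
        history.append(lr)
    return history


# 'val_loss' with mode='min': the loss stalls at 0.9 from the second epoch on.
rates = plateau_schedule([1.0, 0.9, 0.9, 0.9, 0.9], lr=0.001, patience=2)
print(rates)
```

With mode='max' (as used by tasks that monitor a mean average precision), the same logic applies with the comparison reversed.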

forward(*args, **kwargs)[source]#

Forward pass of the model.

Parameters:
  • args (Any) – Arguments to pass to model.

  • kwargs (Any) – Keyword arguments to pass to model.

Returns:

Output of the model.

Return type:

Any

class torchgeo.trainers.ChangeDetectionTask(model='unet', backbone='resnet50', weights=None, in_channels=3, task='binary', num_classes=None, num_labels=None, labels=None, num_filters=3, pos_weight=None, loss='bce', class_weights=None, ignore_index=None, lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]#

Bases: ClassificationMixin, BaseTask

Change Detection. Supports binary, multiclass, and multilabel change detection.

Added in version 0.8.

__init__(model='unet', backbone='resnet50', weights=None, in_channels=3, task='binary', num_classes=None, num_labels=None, labels=None, num_filters=3, pos_weight=None, loss='bce', class_weights=None, ignore_index=None, lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]#

Initialize a new ChangeDetectionTask instance.

Parameters:
  • model (Literal['unet', 'deeplabv3+', 'fcn', 'upernet', 'segformer', 'dpt', 'fcsiamdiff', 'fcsiamconc', 'changevit', 'btc']) – Name of the model to use.

  • backbone (str) – Name of the timm or smp backbone to use.

  • weights (WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict. FCN model does not support pretrained weights.

  • in_channels (int) – Number of channels per image.

  • task (Literal['binary', 'multiclass', 'multilabel']) – One of ‘binary’, ‘multiclass’, or ‘multilabel’.

  • num_classes (int | None) – Number of prediction classes (only for task='multiclass').

  • num_labels (int | None) – Number of prediction labels (only for task='multilabel').

  • labels (list[str] | None) – List of class names.

  • num_filters (int) – Number of filters. Only applicable when model='fcn'.

  • pos_weight (Tensor | None) – Weight of positive examples, used with ‘bce’ loss.

  • loss (Literal['ce', 'bce', 'jaccard', 'focal', 'dice']) – Name of the loss function, currently supports ‘ce’, ‘bce’, ‘jaccard’, ‘focal’, and ‘dice’ loss.

  • class_weights (Tensor | Sequence[float] | None) – Optional rescaling weight given to each class, used with ‘ce’ loss.

  • ignore_index (int | None) – Optional integer class index to ignore in the loss and metrics.

  • lr (float) – Learning rate for optimizer.

  • patience (int) – Patience for learning rate scheduler.

  • freeze_backbone (bool) – Freeze the backbone network to fine-tune the decoder and segmentation head.

  • freeze_decoder (bool) – Freeze the decoder network to linear probe the segmentation head.

Added in version 0.9: The labels parameter.

configure_models()[source]#

Initialize the model.

training_step(batch, batch_idx)[source]#

Compute the training loss and additional metrics.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

Returns:

The loss tensor.

Return type:

Tensor

validation_step(batch, batch_idx)[source]#

Compute the validation loss and additional metrics.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

test_step(batch, batch_idx)[source]#

Compute the test loss and additional metrics.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

predict_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the predicted class probabilities.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

Output predicted probabilities.

Return type:

Tensor

class torchgeo.trainers.ClassificationMixin(*args, **kwargs)[source]#

Bases: LightningModule

Mix-in for classification-based tasks.

Added in version 0.9.

configure_losses()[source]#

Initialize the loss criterion.

configure_metrics()[source]#

Initialize the performance metrics.

Includes the following metrics:

See https://en.wikipedia.org/wiki/Evaluation_of_binary_classifiers for more details.

Higher values are better for all metrics. All metrics report multiple versions:

  • Overall (micro): Sum statistics over all labels

  • Average (macro): Calculate statistics for each label and average them

  • Classwise (none): Calculate statistics for each label and apply no reduction
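The difference between the three versions can be sketched with plain Python. This is an illustration only, using recall as the example statistic and hypothetical per-class counts; it is not the metric implementation used by the mix-in.

```python
def recall_versions(tp, fn):
    """Compute micro, macro, and classwise recall from per-class counts.

    tp[i] and fn[i] are the true positives and false negatives of class i.
    """
    # Classwise (none): one value per class, no reduction.
    classwise = [t / (t + f) if (t + f) else 0.0 for t, f in zip(tp, fn)]
    # Overall (micro): sum the statistics over all classes first.
    total = sum(tp) + sum(fn)
    micro = sum(tp) / total if total else 0.0
    # Average (macro): average the per-class values.
    macro = sum(classwise) / len(classwise)
    return micro, macro, classwise


# A majority class with 100 samples and a minority class with 10 samples.
micro, macro, classwise = recall_versions(tp=[90, 5], fn=[10, 5])
print(micro, macro, classwise)
```

In this example the micro value is dominated by the majority class, while the macro value weights both classes equally and is therefore pulled down by the weaker minority class.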

on_train_epoch_end()[source]#

Log train metrics.

on_validation_epoch_end()[source]#

Log validation metrics.

on_test_epoch_end()[source]#

Log test metrics.

class torchgeo.trainers.ClassificationTask(model='resnet50', weights=None, in_channels=3, task='multiclass', num_classes=None, num_labels=None, labels=None, pos_weight=None, loss='ce', class_weights=None, ignore_index=None, lr=0.001, patience=10, freeze_backbone=False)[source]#

Bases: ClassificationMixin, BaseTask

Image classification.

__init__(model='resnet50', weights=None, in_channels=3, task='multiclass', num_classes=None, num_labels=None, labels=None, pos_weight=None, loss='ce', class_weights=None, ignore_index=None, lr=0.001, patience=10, freeze_backbone=False)[source]#

Initialize a new ClassificationTask instance.

Parameters:
  • model (str) – Name of the timm model to use.

  • weights (WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.

  • in_channels (int) – Number of input channels to model.

  • task (Literal['binary', 'multiclass', 'multilabel']) – One of ‘binary’, ‘multiclass’, or ‘multilabel’.

  • num_classes (int | None) – Number of prediction classes (only for task='multiclass').

  • num_labels (int | None) – Number of prediction labels (only for task='multilabel').

  • labels (list[str] | None) – List of class names.

  • pos_weight (Tensor | None) – Weight of positive examples, used with ‘bce’ loss.

  • loss (Literal['ce', 'bce', 'jaccard', 'focal', 'dice']) – One of ‘ce’, ‘bce’, ‘jaccard’, ‘focal’, or ‘dice’.

  • class_weights (Tensor | Sequence[float] | None) – Optional rescaling weight given to each class, used with ‘ce’ loss.

  • ignore_index (int | None) – Optional integer class index to ignore in the loss and metrics.

  • lr (float) – Learning rate for optimizer.

  • patience (int) – Patience for learning rate scheduler.

  • freeze_backbone (bool) – Freeze the backbone network to linear probe the classifier head.

Added in version 0.9: The labels, pos_weight, and ignore_index parameters and dice loss support.

Added in version 0.7: The task and num_labels parameters.

Added in version 0.5: The class_weights and freeze_backbone parameters.

Changed in version 0.5: learning_rate and learning_rate_schedule_patience were renamed to lr and patience.

Changed in version 0.4: classification_model was renamed to model.

configure_models()[source]#

Initialize the model.

training_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the training loss and additional metrics.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the validation loss and additional metrics.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

test_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the test loss and additional metrics.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

predict_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the predicted class probabilities.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

Output predicted probabilities.

Return type:

Tensor

class torchgeo.trainers.IOBenchTask[source]#

Bases: BaseTask

I/O benchmarking.

Added in version 0.6.

configure_models()[source]#

No-op.

configure_optimizers()[source]#

Initialize the optimizer.

Returns:

Optimizer.

Return type:

Optimizer | Sequence[Optimizer] | tuple[Sequence[Optimizer], Sequence[LRScheduler | ReduceLROnPlateau | LRSchedulerConfig]] | OptimizerConfig | OptimizerLRSchedulerConfig | Sequence[OptimizerConfig] | Sequence[OptimizerLRSchedulerConfig] | None

training_step(batch, batch_idx, dataloader_idx=0)[source]#

No-op.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

Zero.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]#

No-op.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

test_step(batch, batch_idx, dataloader_idx=0)[source]#

No-op.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

predict_step(batch, batch_idx, dataloader_idx=0)[source]#

No-op.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

class torchgeo.trainers.InstanceSegmentationTask(model='mask-rcnn', backbone='resnet50', weights=None, weights_backbone=None, in_channels=3, num_classes=91, lr=0.001, patience=10, freeze_backbone=False)[source]#

Bases: BaseTask

Instance Segmentation.

Added in version 0.7.

ignore = ('weights', 'weights_backbone')#

Parameters to ignore when saving hyperparameters.

monitor = 'val_segm_map'#

Performance metric to monitor in learning rate scheduler and callbacks.

mode = 'max'#

Whether the goal is to minimize or maximize the performance metric to monitor.

__init__(model='mask-rcnn', backbone='resnet50', weights=None, weights_backbone=None, in_channels=3, num_classes=91, lr=0.001, patience=10, freeze_backbone=False)[source]#

Initialize a new InstanceSegmentationTask instance.

Note that we disable the internal normalize+resize transform of the MaskRCNN model. Please ensure your images are appropriately resized before passing them to the model.

Parameters:
  • model (str) – Name of the model to use.

  • backbone (str) – Name of the backbone to use.

  • weights (WeightsEnum | None) – Initial model weights.

  • weights_backbone (WeightsEnum | None) – Initial backbone weights.

  • in_channels (int) – Number of input channels to model.

  • num_classes (int) – Number of prediction classes (including the background).

  • lr (float) – Learning rate for optimizer.

  • patience (int) – Patience for learning rate scheduler.

  • freeze_backbone (bool) – Freeze the backbone network to fine-tune the decoder and segmentation head.

Added in version 0.9: The weights_backbone parameter.

configure_models()[source]#

Initialize the model.

Raises:

ValueError – If model or backbone are invalid.

configure_metrics()[source]#

Initialize the performance metrics.

  • MeanAveragePrecision: Mean average precision (mAP) and mean average recall (mAR). Precision is the number of true positives divided by the number of true positives + false positives. Recall is the number of true positives divided by the number of true positives + false negatives. Uses ‘macro’ averaging. Higher values are better.

Note

  • ‘Micro’ averaging suits overall performance evaluation but may not reflect minority class accuracy.

  • ‘Macro’ averaging gives equal weight to each class, and is useful for balanced performance assessment across imbalanced classes.
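The precision and recall definitions above can be written out directly. This is a minimal sketch with hypothetical counts; the real metric additionally matches predictions to ground truth by overlap and averages over thresholds, which is not shown here.

```python
def precision_recall(tp, fp, fn):
    """Return (precision, recall) from true/false positive and negative counts."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall


# 8 correct detections, 2 spurious detections, 8 missed objects.
precision, recall = precision_recall(tp=8, fp=2, fn=8)
print(precision, recall)
```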

training_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the training loss.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the validation metrics.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

test_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the test metrics.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

predict_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the predicted masks.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

Output predicted masks.

Return type:

list[dict[str, Tensor]]

class torchgeo.trainers.MAETask(model='vit_base_patch32_224', weights=None, in_channels=3, transform=None, decoder_dim=512, lr=0.00015, decoder_num_heads=8, decoder_depth=1, weight_decay=0.05, mask_ratio=0.75, size=224, norm_pix_loss=True, warmup_epochs=40)[source]#

Bases: BaseTask

MAE: Masked Autoencoder for self-supervised learning.

Reference implementations:

If you use this code for your research, please cite the original paper:

ignore = ('transform', 'weights')#

Parameters to ignore when saving hyperparameters.

__init__(model='vit_base_patch32_224', weights=None, in_channels=3, transform=None, decoder_dim=512, lr=0.00015, decoder_num_heads=8, decoder_depth=1, weight_decay=0.05, mask_ratio=0.75, size=224, norm_pix_loss=True, warmup_epochs=40)[source]#

Initialize the MAE task.

Parameters:
  • model (str) – The ViT architecture to use for the encoder. Must be compatible with timm’s create_model function.

  • weights (WeightsEnum | str | bool | None) – Pretrained weights to initialize the encoder with. Can be a timm WeightsEnum or a string identifier for a timm weight, True to use default pretrained weights, or None for random initialization.

  • in_channels (int) – Number of input channels in the images. Must match the in_chans argument of the ViT model.

  • transform (Module | None) – Optional transform to apply to the input images. If None, a default MAE augmentation will be used.

  • decoder_dim (int) – The embedding dimension of the MAE decoder. Typically 512 is a good choice for ViT-Base encoders.

  • lr (float) – Should typically be set to 1.5e-4 * batch_size / 256.

  • decoder_num_heads (int) – Number of attention heads in the MAE decoder.

  • decoder_depth (int) – Number of layers in the MAE decoder. Typically 1-4 layers is sufficient for good performance.

  • weight_decay (float) – Weight decay for the AdamW optimizer.

  • mask_ratio (float) – The ratio of tokens to mask during training. Typically 0.75 is a good choice.

  • size (int) – The input image size (height and width) after augmentation. Must match the input size expected by the ViT model.

  • norm_pix_loss (bool) – If True, normalize each target patch to zero mean and unit variance before computing MSE. Recommended by the original MAE paper.

  • warmup_epochs (int) – Number of linear warmup epochs before cosine annealing.
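The lr rule and the warmup_epochs description above can be sketched with standard-library Python. The function names and the max_epochs argument are hypothetical, and the exact shape of the schedule (per-epoch stepping, decay to zero) is an assumption rather than a statement about the TorchGeo implementation.

```python
import math


def scaled_lr(batch_size, base_lr=1.5e-4):
    """Linear scaling rule from the lr description: base_lr * batch_size / 256."""
    return base_lr * batch_size / 256


def warmup_cosine_factor(epoch, warmup_epochs=40, max_epochs=200):
    """Multiplier applied to the learning rate at a given (0-indexed) epoch.

    Linear warmup for ``warmup_epochs`` epochs, then cosine annealing to zero.
    """
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, max_epochs - warmup_epochs)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


lr = scaled_lr(batch_size=1024)  # four times the base rate
for epoch in (0, 39, 120, 200):
    print(epoch, lr * warmup_cosine_factor(epoch))
```

Under these assumptions the rate rises linearly to its scaled peak at the end of warmup and then decays along a half cosine.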

configure_losses()[source]#

Initialize the loss criterion.

configure_models()[source]#

Initialize the model.

configure_optimizers()[source]#

Initialize the optimizer and learning rate scheduler.

Returns:

Optimizer and learning rate scheduler.

Return type:

Optimizer | Sequence[Optimizer] | tuple[Sequence[Optimizer], Sequence[LRScheduler | ReduceLROnPlateau | LRSchedulerConfig]] | OptimizerConfig | OptimizerLRSchedulerConfig | Sequence[OptimizerConfig] | Sequence[OptimizerLRSchedulerConfig] | None

forward(images, idx_keep, idx_mask)[source]#

Forward pass through MAE encoder and decoder.

Parameters:
  • images (Tensor) – The input images, with shape (B, in_channels, H, W).

  • idx_keep (Tensor) – The indices of the tokens that were kept (not masked), with shape (B, N_keep).

  • idx_mask (Tensor) – The indices of the tokens that were masked, with shape (B, N_mask).

Returns:

The predicted pixel values for the masked tokens, with shape (B, N_mask, patch_size*patch_size*in_channels).

Return type:

Tensor
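The size of the last output dimension follows from simple arithmetic. The sketch below assumes a patch size of 32, read off the default model name 'vit_base_patch32_224'; the helper name is hypothetical. It deliberately does not compute N_keep or N_mask, since how the mask ratio is rounded and whether a class token is counted depend on the implementation.

```python
def mae_patch_shapes(size=224, patch_size=32, in_channels=3):
    """Return (number of patches per image, flattened pixels per patch)."""
    patches_per_side = size // patch_size
    num_patches = patches_per_side ** 2
    patch_dim = patch_size * patch_size * in_channels
    return num_patches, patch_dim


num_patches, patch_dim = mae_patch_shapes()
print(num_patches, patch_dim)
```

With the default arguments, each image is split into 49 patches and each predicted patch has 3072 values.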

training_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the training loss and additional metrics.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]#

No-op, does nothing.

test_step(batch, batch_idx, dataloader_idx=0)[source]#

No-op, does nothing.

predict_step(batch, batch_idx, dataloader_idx=0)[source]#

No-op, does nothing.

class torchgeo.trainers.MoCoTask(model='resnet50', weights=None, in_channels=3, version=3, layers=3, hidden_dim=4096, output_dim=256, lr=9.6, weight_decay=1e-06, momentum=0.9, schedule=[120, 160], temperature=1, memory_bank_size=0, moco_momentum=0.99, gather_distributed=False, size=224, grayscale_weights=None, augmentation1=None, augmentation2=None)[source]#

Bases: BaseTask

MoCo: Momentum Contrast.

Reference implementations:

If you use this trainer in your research, please cite the following papers:

Added in version 0.5.

ignore = ('weights', 'augmentation1', 'augmentation2')#

Parameters to ignore when saving hyperparameters.

monitor = 'train_loss'#

Performance metric to monitor in learning rate scheduler and callbacks.

__init__(model='resnet50', weights=None, in_channels=3, version=3, layers=3, hidden_dim=4096, output_dim=256, lr=9.6, weight_decay=1e-06, momentum=0.9, schedule=[120, 160], temperature=1, memory_bank_size=0, moco_momentum=0.99, gather_distributed=False, size=224, grayscale_weights=None, augmentation1=None, augmentation2=None)[source]#

Initialize a new MoCoTask instance.

Parameters:
  • model (str) – Name of the timm model to use.

  • weights (WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.

  • in_channels (int) – Number of input channels to model.

  • version (int) – Version of MoCo, 1–3.

  • layers (int) – Number of layers in projection head (not used in v1, 2 for v2, 3 for v3).

  • hidden_dim (int) – Number of hidden dimensions in projection head (not used in v1, 2048 for v2, 4096 for v3).

  • output_dim (int) – Number of output dimensions in projection head (not used in v1, 128 for v2, 256 for v3).

  • lr (float) – Learning rate (0.03 x batch_size / 256 for v1/2, 0.6 x batch_size / 256 for v3).

  • weight_decay (float) – Weight decay coefficient (1e-4 for v1/2, 1e-6 for v3).

  • momentum (float) – Momentum of SGD solver (v1/2 only).

  • schedule (Sequence[int]) – Epochs at which to drop lr by 10x (v1/2 only).

  • temperature (float) – Temperature used in InfoNCE loss (0.07 for v1/2, 1 for v3).

  • memory_bank_size (int) – Size of memory bank (65536 for v1/2, 0 for v3).

  • moco_momentum (float) – MoCo momentum of updating key encoder (0.999 for v1/2, 0.99 for v3).

  • gather_distributed (bool) – Gather negatives from all GPUs during distributed training (ignored if memory_bank_size > 0).

  • size (int) – Size of patch to crop.

  • grayscale_weights (Tensor | None) – Weight vector for grayscale computation, see RandomGrayscale. Only used with the default augmentations (when augmentation1 or augmentation2 is None). Defaults to average of all bands.

  • augmentation1 (Module | None) – Data augmentation for 1st branch. Defaults to MoCo augmentation.

  • augmentation2 (Module | None) – Data augmentation for 2nd branch. Defaults to MoCo augmentation.
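The lr and schedule descriptions above reduce to a few lines of arithmetic. This is a sketch with hypothetical helper names. It assumes the drop takes effect at the listed epoch, in the style of a milestone-based step scheduler.

```python
def moco_lr(batch_size, version=3):
    """Learning rate rule from the lr description above."""
    base = 0.6 if version == 3 else 0.03
    return base * batch_size / 256


def stepped_lr(base_lr, epoch, schedule=(120, 160)):
    """Drop the learning rate by 10x at each epoch in ``schedule`` (v1/2 only)."""
    drops = sum(1 for milestone in schedule if epoch >= milestone)
    return base_lr * 0.1 ** drops


print(moco_lr(4096, version=3))   # matches the default lr=9.6
print(moco_lr(256, version=2))
print([stepped_lr(0.03, e) for e in (0, 120, 160)])
```

Under this rule, the default lr of 9.6 corresponds to MoCo v3 with a batch size of 4096.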

Raises:

AssertionError – If an invalid version of MoCo is requested.

Warns:

UserWarning – If hyperparameters do not match MoCo version requested.

configure_models()[source]#

Initialize the model.

configure_losses()[source]#

Initialize the loss criterion.

configure_optimizers()[source]#

Initialize the optimizer and learning rate scheduler.

Returns:

Optimizer and learning rate scheduler.

Return type:

Optimizer | Sequence[Optimizer] | tuple[Sequence[Optimizer], Sequence[LRScheduler | ReduceLROnPlateau | LRSchedulerConfig]] | OptimizerConfig | OptimizerLRSchedulerConfig | Sequence[OptimizerConfig] | Sequence[OptimizerLRSchedulerConfig] | None

forward(x)[source]#

Forward pass of the model.

Parameters:

x (Tensor) – Mini-batch of images.

Returns:

Output of the model and backbone.

Return type:

tuple[Tensor, Tensor]

forward_momentum(x)[source]#

Forward pass of the momentum model.

Parameters:

x (Tensor) – Mini-batch of images.

Returns:

Output from the momentum model.

Return type:

Tensor
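In MoCo, the momentum (key) encoder is not trained by gradient descent; its parameters track the query encoder as an exponential moving average controlled by moco_momentum. The sketch below shows that update rule on plain lists of floats. The function name is hypothetical, and how TorchGeo applies the update internally is not shown.

```python
def momentum_update(key_params, query_params, m=0.99):
    """Exponential moving average update: key <- m * key + (1 - m) * query."""
    return [m * k + (1.0 - m) * q for k, q in zip(key_params, query_params)]


key = [1.0, -2.0]
query = [0.0, 2.0]
key = momentum_update(key, query, m=0.99)
print(key)
```

A momentum close to 1 (0.999 for v1/2, 0.99 for v3) means the key encoder changes slowly, which keeps its outputs consistent across training steps.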

training_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the training loss and additional metrics.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]#

No-op, does nothing.

test_step(batch, batch_idx, dataloader_idx=0)[source]#

No-op, does nothing.

predict_step(batch, batch_idx, dataloader_idx=0)[source]#

No-op, does nothing.

class torchgeo.trainers.ObjectDetectionTask(model='faster-rcnn', backbone='resnet50', weights=None, in_channels=3, num_classes=1000, trainable_layers=3, lr=0.001, patience=10, freeze_backbone=False)[source]#

Bases: BaseTask

Object detection.

Added in version 0.4.

monitor = 'val_map'#

Performance metric to monitor in learning rate scheduler and callbacks.

mode = 'max'#

Whether the goal is to minimize or maximize the performance metric to monitor.

__init__(model='faster-rcnn', backbone='resnet50', weights=None, in_channels=3, num_classes=1000, trainable_layers=3, lr=0.001, patience=10, freeze_backbone=False)[source]#

Initialize a new ObjectDetectionTask instance.

Note that we disable the internal normalize+resize transform of the detection models. Please ensure your images are appropriately resized before passing them to the model.

Parameters:
  • model (str) – Name of the torchvision model to use. One of ‘faster-rcnn’, ‘fcos’, or ‘retinanet’.

  • backbone (str) – Name of the torchvision backbone to use. One of ‘resnet18’, ‘resnet34’, ‘resnet50’, ‘resnet101’, ‘resnet152’, ‘resnext50_32x4d’, ‘resnext101_32x8d’, ‘wide_resnet50_2’, or ‘wide_resnet101_2’.

  • weights (WeightsEnum | None) – Initial model weights.

  • in_channels (int) – Number of input channels to model.

  • num_classes (int) – Number of prediction classes (including the background).

  • trainable_layers (int) – Number of trainable layers.

  • lr (float) – Learning rate for optimizer.

  • patience (int) – Patience for learning rate scheduler.

  • freeze_backbone (bool) – Freeze the backbone network to fine-tune the detection head.

Changed in version 0.4: detection_model was renamed to model.

Added in version 0.5: The freeze_backbone parameter.

Changed in version 0.5: pretrained, learning_rate, and learning_rate_schedule_patience were renamed to weights, lr, and patience.

configure_models()[source]#

Initialize the model.

Raises:

ValueError – If model or backbone are invalid.

configure_metrics()[source]#

Initialize the performance metrics.

  • MeanAveragePrecision: Mean average precision (mAP) and mean average recall (mAR). Precision is the number of true positives divided by the number of true positives + false positives. Recall is the number of true positives divided by the number of true positives + false negatives. Uses ‘macro’ averaging. Higher values are better.

Note

  • ‘Micro’ averaging suits overall performance evaluation but may not reflect minority class accuracy.

  • ‘Macro’ averaging gives equal weight to each class, and is useful for balanced performance assessment across imbalanced classes.

training_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the training loss.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the validation metrics.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

test_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the test metrics.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

predict_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the predicted bounding boxes.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Integer displaying index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

Output predicted bounding boxes.

Return type:

list[dict[str, Tensor]]

class torchgeo.trainers.PixelwiseRegressionTask(model='resnet50', backbone='resnet50', weights=None, in_channels=3, num_outputs=1, num_filters=3, loss='mse', lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]#

Bases: RegressionTask

LightningModule for pixelwise regression of images.

Added in version 0.5.

configure_models()[source]#

Initialize the model.

class torchgeo.trainers.RegressionTask(model='resnet50', backbone='resnet50', weights=None, in_channels=3, num_outputs=1, num_filters=3, loss='mse', lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]#

Bases: BaseTask

Regression.

__init__(model='resnet50', backbone='resnet50', weights=None, in_channels=3, num_outputs=1, num_filters=3, loss='mse', lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]#

Initialize a new RegressionTask instance.

Parameters:
  • model (str) – Name of the timm or smp model to use.

  • backbone (str) – Name of the timm or smp backbone to use. Only applicable to PixelwiseRegressionTask.

  • weights (WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.

  • in_channels (int) – Number of input channels to model.

  • num_outputs (int) – Number of prediction outputs.

  • num_filters (int) – Number of filters. Only applicable when model='fcn'.

  • loss (str) – One of ‘mse’ or ‘mae’.

  • lr (float) – Learning rate for optimizer.

  • patience (int) – Patience for learning rate scheduler.

  • freeze_backbone (bool) – Freeze the backbone network to linear probe the regression head. Does not support FCN models.

  • freeze_decoder (bool) – Freeze the decoder network to linear probe the regression head. Does not support FCN models. Only applicable to PixelwiseRegressionTask.

Changed in version 0.4: Changed regression model support from torchvision.models to timm.

Added in version 0.5: The freeze_backbone and freeze_decoder parameters.

Changed in version 0.5: learning_rate and learning_rate_schedule_patience were renamed to lr and patience.

configure_models()[source]#

Initialize the model.

configure_losses()[source]#

Initialize the loss criterion.

Raises:

ValueError – If loss is invalid.

configure_metrics()[source]#

Initialize the performance metrics.

  • MeanSquaredError: The average of the squared differences between the predicted and actual values (MSE) and its square root (RMSE). Lower values are better.

  • MeanAbsoluteError: The average of the absolute differences between the predicted and actual values (MAE). Lower values are better.
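The relationship between the metrics listed above can be sketched in plain Python. This is an illustrative computation on hypothetical values, not the metric objects the trainer itself creates; the function name `regression_metrics` and the sample numbers are assumptions made for the example.

```python
import math


def regression_metrics(preds, targets):
    """Return MSE, RMSE, and MAE for two equal-length sequences of floats."""
    if len(preds) != len(targets) or not preds:
        raise ValueError("preds and targets must be non-empty and the same length")
    errors = [p - t for p, t in zip(preds, targets)]
    mse = sum(e * e for e in errors) / len(errors)   # mean of squared differences
    mae = sum(abs(e) for e in errors) / len(errors)  # mean of absolute differences
    return {"mse": mse, "rmse": math.sqrt(mse), "mae": mae}


# Hypothetical predictions and targets; the errors are 1, 0, and -3.
metrics = regression_metrics([2.0, 4.0, 6.0], [1.0, 4.0, 9.0])
```

Because the errors are squared before averaging, the single error of 3 dominates MSE and RMSE far more than it dominates MAE, which is one reason to report both kinds of metric.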

training_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the training loss and additional metrics.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the validation loss and additional metrics.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

test_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the test loss and additional metrics.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

predict_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the predicted regression values.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

Output predicted regression values.

Return type:

Tensor

class torchgeo.trainers.SemanticSegmentationTask(model='unet', backbone='resnet50', weights=None, in_channels=3, task='multiclass', num_classes=None, num_labels=None, labels=None, num_filters=3, pos_weight=None, loss='ce', class_weights=None, ignore_index=None, lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]#

Bases: ClassificationMixin, BaseTask

Semantic Segmentation.

__init__(model='unet', backbone='resnet50', weights=None, in_channels=3, task='multiclass', num_classes=None, num_labels=None, labels=None, num_filters=3, pos_weight=None, loss='ce', class_weights=None, ignore_index=None, lr=0.001, patience=10, freeze_backbone=False, freeze_decoder=False)[source]#

Initialize a new SemanticSegmentationTask instance.

Parameters:
  • model (Literal['unet', 'deeplabv3+', 'fcn', 'upernet', 'segformer', 'dpt']) – Name of the smp model to use.

  • backbone (str) – Name of the timm or smp backbone to use.

  • weights (WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict. FCN model does not support pretrained weights.

  • in_channels (int) – Number of input channels to model.

  • task (Literal['binary', 'multiclass', 'multilabel']) – One of ‘binary’, ‘multiclass’, or ‘multilabel’.

  • num_classes (int | None) – Number of prediction classes (only for task='multiclass').

  • num_labels (int | None) – Number of prediction labels (only for task='multilabel').

  • labels (list[str] | None) – List of class names.

  • num_filters (int) – Number of filters. Only applicable when model='fcn'.

  • pos_weight (Tensor | None) – Weight of positive examples, used with ‘bce’ loss.

  • loss (Literal['ce', 'bce', 'jaccard', 'focal', 'dice']) – Name of the loss function, currently supports ‘ce’, ‘bce’, ‘jaccard’, ‘focal’, and ‘dice’ loss.

  • class_weights (Tensor | Sequence[float] | None) – Optional rescaling weight given to each class, used with ‘ce’ loss.

  • ignore_index (int | None) – Optional integer class index to ignore in the loss and metrics.

  • lr (float) – Learning rate for optimizer.

  • patience (int) – Patience for learning rate scheduler.

  • freeze_backbone (bool) – Freeze the backbone network to fine-tune the decoder and segmentation head.

  • freeze_decoder (bool) – Freeze the decoder network to linear probe the segmentation head.

Added in version 0.9: The labels and pos_weight parameters and dice loss support.

Added in version 0.8: Time series, DPT, Segformer, and UPerNet support.

Added in version 0.7: The task and num_labels parameters.

Changed in version 0.6: The ignore_index parameter now works for jaccard loss.

Added in version 0.5: The class_weights, freeze_backbone, and freeze_decoder parameters.

Changed in version 0.5: The weights parameter now supports WeightEnums and checkpoint paths. learning_rate and learning_rate_schedule_patience were renamed to lr and patience.

Changed in version 0.4: segmentation_model, encoder_name, and encoder_weights were renamed to model, backbone, and weights.

Changed in version 0.3: ignore_zeros was renamed to ignore_index.
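The task-specific parameters described above (num_classes only for task='multiclass', num_labels only for task='multilabel') can be kept consistent with a small helper. The sketch below is a hypothetical convenience wrapper, not part of torchgeo: the names `segmentation_kwargs` and `build_task` are invented here, and pairing 'ce' with multiclass and 'bce' with binary or multilabel is an assumption chosen for illustration.

```python
def segmentation_kwargs(task, count=None):
    """Return task-specific keyword arguments for SemanticSegmentationTask.

    `count` is the number of classes (multiclass) or labels (multilabel);
    it is not used for binary segmentation.
    """
    if task == "multiclass":
        return {"task": task, "num_classes": count, "loss": "ce"}  # loss pairing is an assumption
    if task == "multilabel":
        return {"task": task, "num_labels": count, "loss": "bce"}  # loss pairing is an assumption
    if task == "binary":
        return {"task": task, "loss": "bce"}  # loss pairing is an assumption
    raise ValueError(f"unknown task: {task!r}")


def build_task(task, count=None, **extra):
    """Construct the trainer; torchgeo is imported lazily, only when this is called."""
    from torchgeo.trainers import SemanticSegmentationTask

    kwargs = {"model": "unet", "backbone": "resnet50"}
    kwargs.update(segmentation_kwargs(task, count))
    kwargs.update(extra)  # caller overrides, e.g. freeze_backbone=True
    return SemanticSegmentationTask(**kwargs)
```

For example, `build_task("multiclass", 5, freeze_backbone=True)` would request a U-Net whose backbone is frozen so that only the decoder and segmentation head are trained.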

forward(x)[source]#

Forward pass of the model.

Parameters:

x (Tensor) – Input tensor of shape (B, C, H, W) or (B, T, C, H, W).

Returns:

Output tensor of shape (B, num_classes, H, W).

Return type:

Tensor

configure_models()[source]#

Initialize the model.

training_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the training loss and additional metrics.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the validation loss and additional metrics.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

test_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the test loss and additional metrics.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

predict_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the predicted class probabilities.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

Dictionary with ‘probabilities’, ‘bounds’, and ‘transform’ keys.

Return type:

dict[str, Tensor | None]

Changed in version 0.9: Changed return type from Tensor to dict with probabilities, bounds, and transform keys.
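A caller that previously received a tensor now has to index the returned dictionary. The sketch below shows one way to turn multiclass probabilities into hard labels. Nested lists stand in for a single sample shaped (num_classes, H, W), and the helper name `hard_labels` and the sample values are hypothetical.

```python
def hard_labels(probabilities):
    """Argmax over the class axis of one sample given as nested lists (num_classes, H, W)."""
    num_classes = len(probabilities)
    height = len(probabilities[0])
    width = len(probabilities[0][0])
    return [
        [max(range(num_classes), key=lambda c: probabilities[c][y][x]) for x in range(width)]
        for y in range(height)
    ]


# Hypothetical stand-in for one sample of a predict_step result.
# 'bounds' and 'transform' may be None, as the return type indicates.
prediction = {
    "probabilities": [[[0.9, 0.2]], [[0.1, 0.8]]],  # 2 classes, H=1, W=2
    "bounds": None,
    "transform": None,
}
labels = hard_labels(prediction["probabilities"])
# With a real batched tensor of shape (B, num_classes, H, W), the equivalent
# would be prediction["probabilities"].argmax(dim=1).
```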

class torchgeo.trainers.SimCLRTask(model='resnet50', weights=None, in_channels=3, version=2, layers=3, hidden_dim=None, output_dim=None, lr=4.8, momentum=0.9, weight_decay=0.0001, temperature=0.07, memory_bank_size=64000, gather_distributed=False, size=224, grayscale_weights=None, augmentations=None)[source]#

Bases: BaseTask

SimCLR: a simple framework for contrastive learning of visual representations.

Reference implementation:

If you use this trainer in your research, please cite the following papers:

Added in version 0.5.

ignore = ('weights', 'augmentations')#

Parameters to ignore when saving hyperparameters.

monitor = 'train_loss'#

Performance metric to monitor in learning rate scheduler and callbacks.

__init__(model='resnet50', weights=None, in_channels=3, version=2, layers=3, hidden_dim=None, output_dim=None, lr=4.8, momentum=0.9, weight_decay=0.0001, temperature=0.07, memory_bank_size=64000, gather_distributed=False, size=224, grayscale_weights=None, augmentations=None)[source]#

Initialize a new SimCLRTask instance.

Added in version 0.6: The momentum parameter.

Parameters:
  • model (str) – Name of the timm model to use.

  • weights (WeightsEnum | str | bool | None) – Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict.

  • in_channels (int) – Number of input channels to model.

  • version (int) – Version of SimCLR, either 1 or 2.

  • layers (int) – Number of layers in projection head (2 for v1, 3+ for v2).

  • hidden_dim (int | None) – Number of hidden dimensions in projection head (defaults to output dimension of model).

  • output_dim (int | None) – Number of output dimensions in projection head (defaults to output dimension of model).

  • lr (float) – Learning rate (0.3 × batch_size / 256 is recommended).

  • momentum (float) – Momentum factor.

  • weight_decay (float) – Weight decay coefficient (1e-6 for v1, 1e-4 for v2).

  • temperature (float) – Temperature used in NT-Xent loss.

  • memory_bank_size (int) – Size of memory bank (0 for v1, 64K for v2).

  • gather_distributed (bool) – Gather negatives from all GPUs during distributed training (ignored if memory_bank_size > 0).

  • size (int) – Size of patch to crop.

  • grayscale_weights (Tensor | None) – Weight vector for grayscale computation, see RandomGrayscale. Only used when augmentations=None. Defaults to average of all bands.

  • augmentations (Module | None) – Data augmentation. Defaults to SimCLR augmentation.

Raises:

AssertionError – If an invalid version of SimCLR is requested.

Warns:

UserWarning – If hyperparameters do not match SimCLR version requested.
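The version-dependent recommendations in the parameter list above can be collected into a small standalone check. This is a sketch under stated assumptions, not the trainer's own validation logic: the function names are hypothetical, and the helper raises ValueError for an unsupported version, whereas the trainer is documented to raise AssertionError.

```python
import math


def recommended_lr(batch_size):
    """Learning rate suggested in the parameter docs: 0.3 x batch_size / 256."""
    return 0.3 * batch_size / 256


def hyperparameter_mismatches(version, layers, weight_decay, memory_bank_size):
    """List the settings that differ from the values documented for a SimCLR version."""
    if version not in (1, 2):
        raise ValueError(f"version must be 1 or 2, got {version}")
    problems = []
    if version == 1:
        if layers != 2:
            problems.append("v1 uses 2 projection layers")
        if not math.isclose(weight_decay, 1e-6):
            problems.append("v1 uses weight_decay=1e-6")
        if memory_bank_size != 0:
            problems.append("v1 uses no memory bank")
    else:
        if layers < 3:
            problems.append("v2 uses 3 or more projection layers")
        if not math.isclose(weight_decay, 1e-4):
            problems.append("v2 uses weight_decay=1e-4")
        if memory_bank_size == 0:
            problems.append("v2 uses a memory bank (64K entries)")
    return problems


# The signature defaults are consistent with version 2.
default_problems = hyperparameter_mismatches(2, 3, 0.0001, 64000)
```

Under the recommended rule, the default lr=4.8 is the value obtained for a batch size of 4096, so smaller batches generally call for a proportionally smaller learning rate.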

configure_models()[source]#

Initialize the model.

configure_losses()[source]#

Initialize the loss criterion.
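The temperature parameter scales the similarities inside the NT-Xent loss. The following is a minimal pure-Python sketch of that loss for two augmented views of a batch, written only to show where the temperature enters. It does not model the memory bank or distributed gathering, and it is not the loss class the trainer instantiates; the function name `nt_xent` is hypothetical.

```python
import math


def nt_xent(view_a, view_b, temperature=0.07):
    """NT-Xent loss for two lists of embeddings; row i of each view is a positive pair."""

    def normalize(v):
        norm = math.sqrt(sum(x * x for x in v))
        return [x / norm for x in v]

    def similarity(u, v):
        # Cosine similarity of unit vectors, divided by the temperature.
        return sum(a * b for a, b in zip(u, v)) / temperature

    z = [normalize(v) for v in view_a + view_b]
    n = len(view_a)
    total = 0.0
    for i in range(2 * n):
        j = (i + n) % (2 * n)  # index of the positive for sample i
        logits = [similarity(z[i], z[k]) for k in range(2 * n) if k != i]
        peak = max(logits)
        # Numerically stable log-sum-exp over the positive and all negatives.
        log_sum_exp = peak + math.log(sum(math.exp(s - peak) for s in logits))
        total += log_sum_exp - similarity(z[i], z[j])
    return total / (2 * n)


# Hypothetical 2-D embeddings: matching views give a low loss, swapped views a high one.
aligned = nt_xent([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
swapped = nt_xent([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
```

A lower temperature sharpens the softmax over similarities, so hard negatives are penalized more strongly.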

forward(x)[source]#

Forward pass of the model.

Parameters:

x (Tensor) – Mini-batch of images.

Returns:

Output of the model and backbone.

Return type:

tuple[Tensor, Tensor]

training_step(batch, batch_idx, dataloader_idx=0)[source]#

Compute the training loss and additional metrics.

Parameters:
  • batch (dict[str, Any]) – The output of your DataLoader.

  • batch_idx (int) – Index of this batch.

  • dataloader_idx (int) – Index of the current dataloader.

Returns:

The loss tensor.

Raises:

AssertionError – If channel dimensions are incorrect.

Return type:

Tensor

validation_step(batch, batch_idx, dataloader_idx=0)[source]#

No-op, does nothing.

test_step(batch, batch_idx, dataloader_idx=0)[source]#

No-op, does nothing.

predict_step(batch, batch_idx, dataloader_idx=0)[source]#

No-op, does nothing.

configure_optimizers()[source]#

Initialize the optimizer and learning rate scheduler.

Changed in version 0.6: Changed from Adam to LARS optimizer.

Returns:

Optimizer and learning rate scheduler.

Return type:

Optimizer | Sequence[Optimizer] | tuple[Sequence[Optimizer], Sequence[LRScheduler | ReduceLROnPlateau | LRSchedulerConfig]] | OptimizerConfig | OptimizerLRSchedulerConfig | Sequence[OptimizerConfig] | Sequence[OptimizerLRSchedulerConfig] | None