Shortcuts

Source code for torchgeo.trainers.classification

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Classification tasks."""

import os
from typing import Any, Dict, cast

import pytorch_lightning as pl
import timm
import torch
import torch.nn as nn
from segmentation_models_pytorch.losses import FocalLoss, JaccardLoss
from torch import Tensor
from torch.nn.modules import Conv2d, Linear
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchmetrics import Accuracy, FBetaScore, JaccardIndex, MetricCollection

from ..datasets.utils import unbind_samples
from . import utils

# https://github.com/pytorch/pytorch/issues/60979
# https://github.com/pytorch/pytorch/pull/61045
Conv2d.__module__ = "nn.Conv2d"
Linear.__module__ = "nn.Linear"


class ClassificationTask(pl.LightningModule):
    """LightningModule for image classification."""

[docs] def config_model(self) -> None: """Configures the model based on kwargs parameters passed to the constructor.""" in_channels = self.hyperparams["in_channels"] classification_model = self.hyperparams["classification_model"] imagenet_pretrained = False custom_pretrained = False if self.hyperparams["weights"] and not os.path.exists( self.hyperparams["weights"] ): if self.hyperparams["weights"] not in ["imagenet", "random"]: raise ValueError( f"Weight type '{self.hyperparams['weights']}' is not valid." ) else: imagenet_pretrained = self.hyperparams["weights"] == "imagenet" custom_pretrained = False else: custom_pretrained = True # Create the model valid_models = timm.list_models(pretrained=True) if classification_model in valid_models: self.model = timm.create_model( classification_model, num_classes=self.hyperparams["num_classes"], in_chans=in_channels, pretrained=imagenet_pretrained, ) else: raise ValueError( f"Model type '{classification_model}' is not a valid timm model." ) if custom_pretrained: name, state_dict = utils.extract_encoder(self.hyperparams["weights"]) if self.hyperparams["classification_model"] != name: raise ValueError( f"Trying to load {name} weights into a " f"{self.hyperparams['classification_model']}" ) self.model = utils.load_state_dict(self.model, state_dict)
[docs] def config_task(self) -> None: """Configures the task based on kwargs parameters passed to the constructor.""" self.config_model() if self.hyperparams["loss"] == "ce": self.loss: nn.Module = nn.CrossEntropyLoss() elif self.hyperparams["loss"] == "jaccard": self.loss = JaccardLoss(mode="multiclass") elif self.hyperparams["loss"] == "focal": self.loss = FocalLoss(mode="multiclass", normalized=True) else: raise ValueError(f"Loss type '{self.hyperparams['loss']}' is not valid.")
[docs] def __init__(self, **kwargs: Any) -> None: """Initialize the LightningModule with a model and loss function. Keyword Args: classification_model: Name of the classification model use loss: Name of the loss function weights: Either "random", "imagenet_only", "imagenet_and_random", or "random_rgb" """ super().__init__() # Creates `self.hparams` from kwargs self.save_hyperparameters() # type: ignore[operator] self.hyperparams = cast(Dict[str, Any], self.hparams) self.config_task() self.train_metrics = MetricCollection( { "OverallAccuracy": Accuracy( num_classes=self.hyperparams["num_classes"], average="micro" ), "AverageAccuracy": Accuracy( num_classes=self.hyperparams["num_classes"], average="macro" ), "JaccardIndex": JaccardIndex( num_classes=self.hyperparams["num_classes"] ), "F1Score": FBetaScore( num_classes=self.hyperparams["num_classes"], beta=1.0, average="micro", ), }, prefix="train_", ) self.val_metrics = self.train_metrics.clone(prefix="val_") self.test_metrics = self.train_metrics.clone(prefix="test_")
[docs] def forward(self, *args: Any, **kwargs: Any) -> Any: """Forward pass of the model. Args: x: input image Returns: prediction """ return self.model(*args, **kwargs)
[docs] def training_step(self, *args: Any, **kwargs: Any) -> Tensor: """Compute and return the training loss. Args: batch: the output of your DataLoader Returns: training loss """ batch = args[0] x = batch["image"] y = batch["label"] y_hat = self.forward(x) y_hat_hard = y_hat.argmax(dim=1) loss = self.loss(y_hat, y) # by default, the train step logs every `log_every_n_steps` steps where # `log_every_n_steps` is a parameter to the `Trainer` object self.log("train_loss", loss, on_step=True, on_epoch=False) self.train_metrics(y_hat_hard, y) return cast(Tensor, loss)
[docs] def training_epoch_end(self, outputs: Any) -> None: """Logs epoch-level training metrics. Args: outputs: list of items returned by training_step """ self.log_dict(self.train_metrics.compute()) self.train_metrics.reset()
[docs] def validation_step(self, *args: Any, **kwargs: Any) -> None: """Compute validation loss and log example predictions. Args: batch: the output of your DataLoader batch_idx: the index of this batch """ batch = args[0] batch_idx = args[1] x = batch["image"] y = batch["label"] y_hat = self.forward(x) y_hat_hard = y_hat.argmax(dim=1) loss = self.loss(y_hat, y) self.log("val_loss", loss, on_step=False, on_epoch=True) self.val_metrics(y_hat_hard, y) if batch_idx < 10: try: datamodule = self.trainer.datamodule # type: ignore[attr-defined] batch["prediction"] = y_hat_hard for key in ["image", "label", "prediction"]: batch[key] = batch[key].cpu() sample = unbind_samples(batch)[0] fig = datamodule.plot(sample) summary_writer = self.logger.experiment summary_writer.add_figure( f"image/{batch_idx}", fig, global_step=self.global_step ) except AttributeError: pass
[docs] def validation_epoch_end(self, outputs: Any) -> None: """Logs epoch level validation metrics. Args: outputs: list of items returned by validation_step """ self.log_dict(self.val_metrics.compute()) self.val_metrics.reset()
[docs] def test_step(self, *args: Any, **kwargs: Any) -> None: """Compute test loss. Args: batch: the output of your DataLoader """ batch = args[0] x = batch["image"] y = batch["label"] y_hat = self.forward(x) y_hat_hard = y_hat.argmax(dim=1) loss = self.loss(y_hat, y) # by default, the test and validation steps only log per *epoch* self.log("test_loss", loss, on_step=False, on_epoch=True) self.test_metrics(y_hat_hard, y)
[docs] def test_epoch_end(self, outputs: Any) -> None: """Logs epoch level test metrics. Args: outputs: list of items returned by test_step """ self.log_dict(self.test_metrics.compute()) self.test_metrics.reset()
[docs] def configure_optimizers(self) -> Dict[str, Any]: """Initialize the optimizer and learning rate scheduler. Returns: a "lr dict" according to the pytorch lightning documentation -- https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers """ optimizer = torch.optim.AdamW( self.model.parameters(), lr=self.hyperparams["learning_rate"] ) return { "optimizer": optimizer, "lr_scheduler": { "scheduler": ReduceLROnPlateau( optimizer, patience=self.hyperparams["learning_rate_schedule_patience"], ), "monitor": "val_loss", }, }
class MultiLabelClassificationTask(ClassificationTask): """LightningModule for multi-label image classification."""
[docs] def config_task(self) -> None: """Configures the task based on kwargs parameters passed to the constructor.""" self.config_model() if self.hyperparams["loss"] == "bce": self.loss = nn.BCEWithLogitsLoss() else: raise ValueError(f"Loss type '{self.hyperparams['loss']}' is not valid.")
[docs] def __init__(self, **kwargs: Any) -> None: """Initialize the LightningModule with a model and loss function. Keyword Args: classification_model: Name of the classification model use loss: Name of the loss function weights: Either "random", "imagenet_only", "imagenet_and_random", or "random_rgb" """ super().__init__(**kwargs) # Creates `self.hparams` from kwargs self.save_hyperparameters() # type: ignore[operator] self.hyperparams = cast(Dict[str, Any], self.hparams) self.config_task() self.train_metrics = MetricCollection( { "OverallAccuracy": Accuracy( num_classes=self.hyperparams["num_classes"], average="micro", multiclass=False, ), "AverageAccuracy": Accuracy( num_classes=self.hyperparams["num_classes"], average="macro", multiclass=False, ), "F1Score": FBetaScore( num_classes=self.hyperparams["num_classes"], beta=1.0, average="micro", multiclass=False, ), }, prefix="train_", ) self.val_metrics = self.train_metrics.clone(prefix="val_") self.test_metrics = self.train_metrics.clone(prefix="test_")
[docs] def training_step(self, *args: Any, **kwargs: Any) -> Tensor: """Compute and return the training loss. Args: batch: the output of your DataLoader Returns: training loss """ batch = args[0] x = batch["image"] y = batch["label"] y_hat = self.forward(x) y_hat_hard = torch.softmax(y_hat, dim=-1) loss = self.loss(y_hat, y.to(torch.float)) # by default, the train step logs every `log_every_n_steps` steps where # `log_every_n_steps` is a parameter to the `Trainer` object self.log("train_loss", loss, on_step=True, on_epoch=False) self.train_metrics(y_hat_hard, y) return cast(Tensor, loss)
[docs] def validation_step(self, *args: Any, **kwargs: Any) -> None: """Compute validation loss and log example predictions. Args: batch: the output of your DataLoader batch_idx: the index of this batch """ batch = args[0] batch_idx = args[1] x = batch["image"] y = batch["label"] y_hat = self.forward(x) y_hat_hard = torch.softmax(y_hat, dim=-1) loss = self.loss(y_hat, y.to(torch.float)) self.log("val_loss", loss, on_step=False, on_epoch=True) self.val_metrics(y_hat_hard, y) if batch_idx < 10: try: datamodule = self.trainer.datamodule # type: ignore[attr-defined] batch["prediction"] = y_hat_hard for key in ["image", "label", "prediction"]: batch[key] = batch[key].cpu() sample = unbind_samples(batch)[0] fig = datamodule.plot(sample) summary_writer = self.logger.experiment summary_writer.add_figure( f"image/{batch_idx}", fig, global_step=self.global_step ) except AttributeError: pass
[docs] def test_step(self, *args: Any, **kwargs: Any) -> None: """Compute test loss. Args: batch: the output of your DataLoader """ batch = args[0] x = batch["image"] y = batch["label"] y_hat = self.forward(x) y_hat_hard = torch.softmax(y_hat, dim=-1) loss = self.loss(y_hat, y.to(torch.float)) # by default, the test and validation steps only log per *epoch* self.log("test_loss", loss, on_step=False, on_epoch=True) self.test_metrics(y_hat_hard, y)

© Copyright 2021, Microsoft Corporation. Revision af389759.

Built with Sphinx using a theme provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources