Source code for torchgeo.models.btc

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.
#
# Based on the original code: https://github.com/blaz-r/BTC-change-detection

"""Be The Change (BTC) change detection model implementation."""

import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from segmentation_models_pytorch.decoders.upernet.decoder import UPerNetDecoder
from torch import Tensor
from torch.nn.modules import Module
from torchvision.models.feature_extraction import create_feature_extractor

from torchgeo.models.swin import (
    Swin_B_Weights,
    Swin_S_Weights,
    Swin_T_Weights,
    swin_b,
    swin_s,
    swin_t,
)


class BTC(Module):
    """Be The Change (BTC) change detection model.

    The model encodes both images of a bi-temporal pair with a shared Swin
    backbone, fuses the per-level features by subtraction and decodes the
    result with a UPerNet decoder followed by a 1x1 segmentation head.

    If you use this model in your research, please cite the following paper:

    * https://arxiv.org/abs/2507.03367

    .. versionadded:: 0.8
    """

    def __init__(
        self, backbone: str, backbone_pretrained: bool = False, classes: int = 1
    ) -> None:
        """Initialise BTC model.

        Args:
            backbone: backbone type (either swin_tiny, swin_small or swin_base).
            backbone_pretrained: whether the cityscapes pretrained swin is used.
            classes: number of classes.
        """
        super().__init__()

        # width shared by the decoder output and the segmentation head input
        decoder_width = 512

        self.encoder = SwinBackbone(backbone, backbone_pretrained=backbone_pretrained)
        self.difference = subtraction_fusion
        stage_channels = self.encoder.channels

        # the smp implementation cuts the first two elements off, so two
        # placeholder entries are prepended to the channel list
        self.decoder = UPerNetDecoder(
            encoder_channels=[0, 0] + list(stage_channels),
            encoder_depth=4,
            decoder_channels=decoder_width,
        )
        # the backbone already applies a layernorm per level, so the decoder's
        # own feature norms are replaced by no-ops
        self.decoder.feature_norms = nn.ModuleList(
            [nn.Identity() for _ in range(len(stage_channels))]
        )

        self.final_layer = smp.base.SegmentationHead(
            in_channels=decoder_width,
            out_channels=classes,
            activation=None,
            kernel_size=1,
            upsampling=0,  # resizing is done in forward in case of uneven factor
        )

        smp.base.model.init.initialize_decoder(self.decoder)
        smp.base.model.init.initialize_head(self.final_layer)

        # placeholder tensors prepended to the fused features for
        # compatibility with the UPerNet implementation
        self.upernet_padding = [torch.tensor(0) for _ in range(2)]

    def forward(self, x: Tensor) -> Tensor:
        """BTC forward call.

        Extract multi-resolution features, fuse by subtraction, decode with UperNet.

        Args:
            x: input image tensor (b, t*c, h, w)

        Returns:
            binary change map prediction [b, n_cls, h, w].
        """
        height, width = x.shape[-2:]

        # the change trainer stacks the time steps in the channel dimension,
        # while the backbone expects them stacked in the batch dimension
        frames = rearrange(x, 'b (t c) h w -> (b t) c h w', c=3)

        # multi-resolution features of every frame
        pyramid = self.encoder(frames)
        # per-level feature difference between the two time steps
        change_features = self.difference(pyramid)

        # the UperNet implementation skips the first two features, so the
        # placeholders are put in front of the real ones
        decoder_inputs = self.upernet_padding + change_features

        logits = self.final_layer(self.decoder(decoder_inputs))

        # resize the prediction to the spatial size of the input image
        return F.interpolate(
            logits, (height, width), mode='bilinear', align_corners=False
        )
class SwinBackbone(Module):
    """Swin backbone for multi-resolution feature extraction."""

    def __init__(
        self, model_size: str = 'swin_base', backbone_pretrained: bool = False
    ) -> None:
        """Initialise swin backbone for multi-resolution feature extraction.

        Args:
            model_size: Swin size, one of 'swin_tiny', 'swin_small', or 'swin_base'.
            backbone_pretrained: whether the cityscapes pretrained swin is used.

        Raises:
            ValueError: if ``model_size`` is not one of the supported sizes.
        """
        super().__init__()

        # the weights enum member is always selected (its transforms are used
        # below), but it is only passed to the builder when pretrained
        if model_size == 'swin_tiny':
            weights = Swin_T_Weights.CITYSCAPES_SEMSEG
            model = swin_t(weights if backbone_pretrained else None)
        elif model_size == 'swin_small':
            weights = Swin_S_Weights.CITYSCAPES_SEMSEG
            model = swin_s(weights if backbone_pretrained else None)
        elif model_size == 'swin_base':
            weights = Swin_B_Weights.CITYSCAPES_SEMSEG
            model = swin_b(weights if backbone_pretrained else None)
        else:
            raise ValueError(
                f'Invalid swin size: {model_size}. Possible options: swin_[tiny | small | base]'
            )

        # we select layers before reduction!
        self.feature_extractor = create_feature_extractor(
            model,
            return_nodes=['features.1', 'features.3', 'features.5', 'features.7'],
        )
        self.channels = self._get_feature_channels()
        self.image_normalization = weights.transforms

        # one layernorm per extracted feature level
        self.norms = nn.ModuleList([nn.LayerNorm(width) for width in self.channels])

        if backbone_pretrained:
            # load pretrained feature norm weights
            pretrained_state = weights.get_state_dict(
                include_norms=True, progress=True
            )
            self.norms.load_state_dict(pretrained_state['feat_norms_state_dict'])

    def forward(self, x: Tensor) -> list[Tensor]:
        """Get multi-resolution features and apply layernorm to each level.

        Args:
            x: input image tensor (b*t, c, h, w).

        Returns:
            list of multi-resolution feature tensors list[(b*t, c, h', w')].
        """
        normalized = self.image_normalization(x)
        pyramid = self.feature_extractor(normalized)

        levels = []
        for level, layer_norm in zip(pyramid.values(), self.norms):
            n, h, w, _ = level.shape
            # features are channel last; normalise, then move channels to dim 1
            levels.append(
                rearrange(layer_norm(level), 'n h w c -> n c h w', n=n, h=h, w=w)
            )
        return levels

    def _get_feature_channels(self) -> list[int]:
        """Get the number of channels in features.

        Runs one forward pass of the feature extractor on a random
        (1, 3, 256, 256) tensor and reads the last dimension of every output.

        Returns:
            list of channels for each feature map in hierarchy.
        """
        extractor = self.feature_extractor
        was_training = extractor.training

        # dry run in eval mode without gradient tracking
        extractor.eval()
        with torch.no_grad():
            probe = torch.rand(1, 3, 256, 256)
            outputs = extractor(probe)
        # revert feature extractor training state
        extractor.train(was_training)

        # torchvision swin is channel last
        return [output.shape[-1] for output in outputs.values()]


def _temporal_difference(feat: Tensor) -> Tensor:
    """Split one feature level into its two time steps and subtract them."""
    first, second = rearrange(feat, '(b t) c h w -> t b c h w', t=2)
    return first - second


def subtraction_fusion(x: list[Tensor]) -> list[Tensor]:
    """Bi-temporal feature fusion by elementwise subtraction.

    Args:
        x: list of multi-resolution feature tensors list[(b*t c h w)].

    Returns:
        fused feature tensors list[(b c h w)].
    """
    return [_temporal_difference(feat) for feat in x]