# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.
"""ChangeViT model implementation.
Based on the paper: https://arxiv.org/pdf/2406.12847
"""
from collections.abc import Sequence
from typing import Any
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
from torch.nn.modules import Module
class DetailCaptureModule(Module):
    """CNN branch that extracts fine-grained spatial detail for ChangeViT.

    Paper states: 'three residual convolutional blocks (C2-C4) adapted from ResNet18'
    that generate 'three-scale detailed features: 1/2, 1/4, and 1/8 resolutions'
    with 'channel dimensions of FCi are set to 64, 128, and 256, respectively.'

    A timm backbone (ResNet18 by default) is built in ``features_only`` mode and
    its first three feature levels are projected with 1x1 convolutions to the
    64 / 128 / 256 channel widths used by the paper.
    """

    def __init__(
        self, in_channels: int = 6, backbone: str = 'resnet18', pretrained: bool = False
    ) -> None:
        """Build the backbone and the per-level channel projections.

        Args:
            in_channels: Number of input channels (typically 6 for bitemporal RGB).
            backbone: Name of the timm backbone model to use.
            pretrained: Whether to load pretrained weights from timm.
        """
        super().__init__()
        self.backbone = timm.create_model(
            backbone,
            in_chans=in_channels,
            pretrained=pretrained,
            features_only=True,
            out_indices=[0, 1, 2],
        )
        # Native channel widths of the three selected backbone feature levels.
        native: list[int] = self.backbone.feature_info.channels()  # type: ignore[union-attr, operator]
        # 1x1 convolutions map each level to the paper's fixed widths.
        self.proj1 = nn.Conv2d(native[0], 64, kernel_size=1)
        self.proj2 = nn.Conv2d(native[1], 128, kernel_size=1)
        self.proj3 = nn.Conv2d(native[2], 256, kernel_size=1)

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Extract and project the three detail feature maps.

        Args:
            x: Bitemporal input tensor [B, 2*C, H, W]

        Returns:
            Tuple of features at 1/2, 1/4, and 1/8 scales with 64, 128, 256 channels
        """
        levels = self.backbone(x)
        return (
            self.proj1(levels[0]),
            self.proj2(levels[1]),
            self.proj3(levels[2]),
        )
class FeatureInjector(Module):
"""Feature injector using cross-attention to inject detail features into ViT.
Implements the cross-attention mechanism described in the ChangeViT paper,
where ViT features serve as queries and detail features as keys/values.
"""
def __init__(
self,
vit_dim: int,
detail_dims: Sequence[int] = (64, 128, 256),
num_heads: int = 8,
) -> None:
"""Initialize the feature injector.
Args:
vit_dim: Dimension of ViT features
detail_dims: Dimensions of detail features at 3 scales (C2, C3, C4)
num_heads: Number of attention heads
"""
super().__init__()
self.cross_attns = nn.ModuleList(
[
nn.MultiheadAttention(
embed_dim=vit_dim, num_heads=num_heads, batch_first=True
)
for _ in range(3)
]
)
self.detail_projs = nn.ModuleList(
[nn.Linear(dim, vit_dim) for dim in detail_dims]
)
self.fusion = nn.Sequential(
nn.Linear(vit_dim * 4, vit_dim),
nn.ReLU(inplace=True),
nn.Linear(vit_dim, vit_dim),
)
def forward(
self, vit_feats: Tensor, detail_feats: tuple[Tensor, Tensor, Tensor]
) -> Tensor:
"""Inject detail features into ViT features via cross-attention.
Args:
vit_feats: ViT features [B, N, D] where N = H*W/patch_size^2
detail_feats: Tuple of detail features at 3 scales
Returns:
Enhanced ViT features [B, N, D]
"""
_b, n, _d = vit_feats.shape
enhanced_feats = [vit_feats]
patch_grid_size = int(n**0.5)
target_spatial = (patch_grid_size, patch_grid_size)
for i, (detail_feat, cross_attn, proj) in enumerate(
zip(detail_feats, self.cross_attns, self.detail_projs)
):
detail_aligned = F.adaptive_avg_pool2d(detail_feat, target_spatial)
detail_flat = detail_aligned.flatten(2).transpose(1, 2)
detail_proj = proj(detail_flat)
enhanced_feat, _ = cross_attn(
query=vit_feats, key=detail_proj, value=detail_proj
)
enhanced_feats.append(enhanced_feat)
fused = torch.cat(enhanced_feats, dim=-1)
result: Tensor = self.fusion(fused)
return result
class ChangeViTDecoder(Module):
    """Change detection decoder for ViT-based models.

    As described in the ChangeViT paper, this decoder handles the final difference
    modeling and change map generation from enhanced ViT features. Both temporal
    orderings (t1->t2 and t2->t1) are decoded in one batched pass.
    """

    def __init__(
        self,
        in_channels: int = 768,
        inner_channels: int = 64,
        num_convs: int = 3,
        num_classes: int = 1,
    ) -> None:
        """Build the convolution stack and prediction head.

        Args:
            in_channels: Input feature dimension (ViT embedding dim)
            inner_channels: Number of inner channels for processing
            num_convs: Number of convolutional layers
            num_classes: Number of output classes
        """
        super().__init__()

        def conv_bn_relu(c_in: int, c_out: int) -> nn.Sequential:
            # 3x3 conv with stride 1 / padding 1 keeps the spatial size unchanged.
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            )

        # The first stage consumes the two temporal maps stacked channel-wise.
        stages = [conv_bn_relu(2 * in_channels, inner_channels)]
        stages += [
            conv_bn_relu(inner_channels, inner_channels) for _ in range(num_convs - 1)
        ]
        self.convs = nn.Sequential(*stages)
        self.head = nn.Conv2d(
            inner_channels, num_classes, kernel_size=3, stride=1, padding=1
        )

    def forward(self, bi_feature: Tensor) -> tuple[Tensor, Tensor]:
        """Predict change logits for both temporal orderings.

        Args:
            bi_feature: Bitemporal features [B, T, C, H, W]

        Returns:
            Tuple of bidirectional change predictions (logits), ``(c12, c21)``
        """
        n = bi_feature.size(0)
        first, second = bi_feature[:, 0], bi_feature[:, 1]
        forward_pair = torch.cat([first, second], dim=1)
        backward_pair = torch.cat([second, first], dim=1)
        logits = self.head(self.convs(torch.cat([forward_pair, backward_pair], dim=0)))
        c12, c21 = logits.split(n, dim=0)
        return c12, c21
class ChangeViT(Module):
    """ChangeViT model for change detection.

    ChangeViT implementation using plain Vision Transformer as backbone
    with detail capture module and feature injection mechanism.

    If you use this model in your research, please cite the following paper:

    * https://arxiv.org/abs/2406.12847

    .. note::

       For best results on LEVIR-CD as reported in the paper, use:

       * Backbone: ``vit_large_patch16_dinov3.sat493m`` (DINOv3-Large pretrained on
         satellite imagery)
       * Loss: Combined BCE+Dice loss (not yet implemented in ChangeDetectionTask)
       * Training: 80k steps with batch size 48
       * Image size: 256x256 patches

    .. versionadded:: 0.8
    """

    def __init__(
        self,
        backbone: str,
        img_size: int = 256,
        in_channels: int = 3,
        num_classes: int = 1,
        pretrained: bool = False,
        **kwargs: Any,
    ) -> None:
        """Initialize ChangeViT model.

        Args:
            backbone: Name of the timm ViT model to use as backbone
                (e.g., 'vit_small_patch14_dinov2', 'vit_tiny_patch16_224')
            img_size: Input image size (default: 256)
            in_channels: Number of input channels per temporal frame (default: 3)
            num_classes: Number of output classes (default: 1)
            pretrained: Whether to load pretrained weights from timm (default: False)
            **kwargs: Additional keyword arguments passed to timm backbone
        """
        super().__init__()
        self.encoder: Any = timm.create_model(
            backbone,
            pretrained=pretrained,
            num_classes=0,
            img_size=img_size,
            dynamic_img_size=True,
            in_chans=in_channels,
            **kwargs,
        )
        embed_dim: int = self.encoder.embed_dim  # type: ignore[assignment]
        # The CNN detail branch sees both frames stacked along the channel axis.
        self.detail_capture = DetailCaptureModule(
            in_channels=in_channels * 2, pretrained=pretrained
        )
        self.feature_injector = FeatureInjector(
            vit_dim=embed_dim, detail_dims=(64, 128, 256)
        )
        self.decoder = ChangeViTDecoder(in_channels=embed_dim, num_classes=num_classes)

    def _patch_grid(self, h: int, w: int) -> tuple[int, int]:
        """Return the ``(rows, cols)`` patch-token grid for an ``h x w`` input."""
        patch_size = self.encoder.patch_embed.patch_size
        if isinstance(patch_size, (tuple, list)):
            # Use both entries so non-square patches are handled correctly.
            patch_h, patch_w = patch_size[0], patch_size[-1]
        else:
            patch_h = patch_w = patch_size
        return h // patch_h, w // patch_w

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of ChangeViT.

        Args:
            x: Bitemporal input tensor [B, T, C, H, W] with T == 2

        Returns:
            Change detection logits [B, num_classes, H, W]

        Raises:
            ValueError: If ``x`` is not 5-D or does not hold exactly two time steps.
        """
        if x.ndim != 5 or x.shape[1] != 2:
            raise ValueError(
                f'Expected input of shape [B, 2, C, H, W], got {tuple(x.shape)}'
            )
        h, w = x.shape[-2], x.shape[-1]
        x_t1 = x[:, 0]
        x_t2 = x[:, 1]
        x_concat = rearrange(x, 'b t c h w -> b (t c) h w')

        vit_features_t1 = self.encoder.forward_features(x_t1)
        vit_features_t2 = self.encoder.forward_features(x_t2)
        detail_features = self.detail_capture(x_concat)

        h_patch, w_patch = self._patch_grid(h, w)
        num_patch_tokens = h_patch * w_patch
        # forward_features prepends prefix tokens (class token plus any register
        # tokens). Skip all of them rather than assuming exactly one; fall back
        # to 1 for encoders that do not expose ``num_prefix_tokens``.
        start = int(getattr(self.encoder, 'num_prefix_tokens', 1))
        end = start + num_patch_tokens
        patch_tokens = (vit_features_t1[:, start:end], vit_features_t2[:, start:end])

        # The same detail features are injected into each frame's tokens.
        enhanced_features_tensor = torch.stack(
            [self.feature_injector(tokens, detail_features) for tokens in patch_tokens],
            dim=1,
        )
        enhanced_spatial = rearrange(
            enhanced_features_tensor, 'b t (h w) d -> b t d h w', h=h_patch, w=w_patch
        )

        # Only the t1->t2 prediction is returned; c21 is discarded here.
        c12, _c21 = self.decoder(enhanced_spatial)
        change_logits: Tensor = F.interpolate(
            c12, size=(h, w), mode='bilinear', align_corners=False
        )
        return change_logits