Source code for torchgeo.models.tilenet

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.
# Adapted from https://github.com/ermongroup/tile2vec. Copyright (c) 2024 Ermon Group

"""TileNet encoder from Tile2Vec.

Reference:
Jean et al., Tile2Vec: Unsupervised Representation Learning
"""

from typing import Any

import torch
import torch.nn.functional as F
import torchvision.transforms.v2 as T
from torch import nn
from torchvision.models._api import Weights, WeightsEnum


[docs] class TileNet_Weights(WeightsEnum): # type: ignore[misc] """TileNet (Tile2Vec) weights. NAIP-pretrained Tile2Vec encoder. .. versionadded:: 0.9 """ NAIP_ALL_TILE2VEC = Weights( url=( 'https://hf.co/pgangapurwala/TileNet_Weights.NAIP_ALL_TILE2VEC/resolve/af12210f5c130af76579ce8ec5e7036c1551ba25/TileNet_Weights.NAIP_ALL_TILE2VEC.pth' ), transforms=T.Normalize(mean=[0], std=[255], inplace=True), meta={ 'dataset': 'NAIP', 'in_chans': 4, 'model': 'tilenet', 'ssl_method': 'tile2vec', 'publication': 'https://arxiv.org/abs/1805.02855', 'repo': 'https://github.com/ermongroup/tile2vec', 'bands': ('R', 'G', 'B', 'NIR'), }, )
class BasicBlock(nn.Module): """Tile2Vec residual block with extra conv3 branch.""" expansion: int = 1 def __init__(self, in_planes: int, planes: int, stride: int = 1) -> None: """Initialize a BasicBlock. Args: in_planes: Number of input channels. planes: Number of output channels. stride: Convolution stride. """ super().__init__() self.conv1 = nn.Conv2d( in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False ) self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d( planes, planes, kernel_size=3, stride=1, padding=1, bias=False ) self.bn2 = nn.BatchNorm2d(planes) # extra conv3/bn3 self.conv3 = nn.Conv2d( planes, planes, kernel_size=3, stride=1, padding=1, bias=True ) self.bn3 = nn.BatchNorm2d(planes) self.shortcut = nn.Sequential() if stride != 1 or in_planes != planes: self.shortcut = nn.Sequential( nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes), ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass.""" out = F.relu(self.bn1(self.conv1(x))) out = F.relu(self.bn2(self.conv2(out))) out = self.bn3(self.conv3(out)) out += self.shortcut(x) return F.relu(out)
[docs] class TileNet(nn.Module): """TileNet encoder. versionadded:: 0.9 """
[docs] def __init__(self, in_channels: int = 4, z_dim: int = 512) -> None: """Initialize TileNet. Args: in_channels: Number of input channels. z_dim: Output embedding dimension. """ super().__init__() self.in_planes: int = 64 self.conv1 = nn.Conv2d( in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False ) self.bn1 = nn.BatchNorm2d(64) self.layer1 = self._make_layer(64, 2, stride=1) self.layer2 = self._make_layer(128, 2, stride=2) self.layer3 = self._make_layer(256, 2, stride=2) self.layer4 = self._make_layer(512, 2, stride=2) self.layer5 = self._make_layer(z_dim, 2, stride=2)
def _make_layer(self, planes: int, num_blocks: int, stride: int) -> nn.Sequential: """Create a sequential residual layer. Args: planes: Number of output channels for each block. num_blocks: Number of residual blocks in the layer. stride: Stride of the first block. Returns: A sequential container of residual blocks. """ strides = [stride] + [1] * (num_blocks - 1) layers = [] for i, s in enumerate(strides): layers.append(BasicBlock(self.in_planes, planes, stride=s)) self.in_planes = planes return nn.Sequential(*layers)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """Compute TileNet embeddings. Args: x: Input image tensor of shape (B, C, H, W). Returns: Embedding tensor of shape (B, embedding_dim). """ x = F.relu(self.bn1(self.conv1(x))) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) x = self.layer5(x) x = F.avg_pool2d(x, 4) return x.view(x.size(0), -1)
[docs] def tilenet( weights: TileNet_Weights | None = None, *args: Any, **kwargs: Any ) -> nn.Module: """TileNet (Tile2Vec) encoder. .. versionadded:: 0.9 Args: weights: Pre-trained TileNet weights to load. *args: Positional arguments. **kwargs: Keyword arguments forwarded to model. Returns: A TileNet model. """ if weights: kwargs['in_channels'] = weights.meta['in_chans'] model = TileNet(*args, **kwargs) if weights: missing_keys, unexpected_keys = model.load_state_dict( weights.get_state_dict(progress=True), strict=True ) assert missing_keys == [] assert unexpected_keys == [] return model