# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.
#
# Copyright (c) 2017 Andrea Palazzi
"""Convolutional Long Short-Term Memory (ConvLSTM) model."""
from typing import cast
import torch
import torch.nn as nn
class ConvLSTMCell(nn.Module):
    """A single ConvLSTM cell module.

    The cell concatenates the input with the previous hidden state along the
    channel dimension and computes all four LSTM gates with one convolution.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        kernel_size: tuple[int, int],
        bias: bool = True,
    ) -> None:
        """Initializes a ConvLSTMCell.

        Args:
            input_dim: Number of channels of input tensor.
            hidden_dim: Number of channels of hidden state.
            kernel_size: Size of the convolutional kernel as (height, width).
            bias: Whether or not to add the bias.
        """
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        # Half-kernel padding: with the default stride of 1 this keeps the
        # spatial size unchanged for odd kernel sizes, which the recurrence
        # relies on (the new state must match the previous state's shape).
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias

        # One convolution yields all four gates, hence 4 * hidden_dim outputs.
        self.conv = nn.Conv2d(
            in_channels=self.input_dim + self.hidden_dim,
            out_channels=4 * self.hidden_dim,
            kernel_size=self.kernel_size,
            padding=self.padding,
            bias=self.bias,
        )

    def forward(
        self, input_tensor: torch.Tensor, cur_state: tuple[torch.Tensor, torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the ConvLSTMCell.

        Args:
            input_tensor: Tensor of shape (b, c, h, w).
            cur_state: Tuple ``(h_cur, c_cur)`` containing the current hidden
                and cell states.

        Returns:
            A tuple ``(h_next, c_next)`` containing the next hidden and cell
            states.
        """
        h_cur, c_cur = cur_state

        # Stack input and previous hidden state along the channel axis.
        combined = torch.cat([input_tensor, h_cur], dim=1)
        combined_conv = self.conv(combined)

        # Channel chunks, in order: input, forget, output gates and candidate.
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)

        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(
        self, batch_size: int, image_size: tuple[int, int]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Initializes the hidden state.

        The zero states are created on the same device and with the same dtype
        as the convolution weights, so they stay compatible with the cell after
        calls such as ``.to(device)``, ``.double()`` or ``.half()``.

        Args:
            batch_size: The batch size.
            image_size: The height and width of the image.

        Returns:
            A tuple of zero tensors for the initial hidden and cell states, each
            of shape (batch_size, hidden_dim, height, width).
        """
        height, width = image_size
        weight = self.conv.weight
        shape = (batch_size, self.hidden_dim, height, width)
        return (
            torch.zeros(shape, device=weight.device, dtype=weight.dtype),
            torch.zeros(shape, device=weight.device, dtype=weight.dtype),
        )
class ConvLSTM(nn.Module):
    """Convolutional LSTM model.

    This model is a sequence-processing model that uses convolutional operations
    within the LSTM cells. It is particularly useful for spatio-temporal data.

    If you use this model in your research, please cite the following paper:

    * https://arxiv.org/abs/1506.04214

    .. versionadded:: 0.8
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int | list[int],
        kernel_size: int | tuple[int, int] | list[int | tuple[int, int]],
        num_layers: int,
        batch_first: bool = True,
        bias: bool = True,
        return_all_layers: bool = False,
    ) -> None:
        """Initializes the ConvLSTM model.

        Args:
            input_dim: Number of channels in the input.
            hidden_dim: Number of hidden channels. Can be a single int (for all
                layers) or a list of ints (one for each layer).
            kernel_size: Size of the convolutional kernel. Can be:

                * a single integer (for square kernels)
                * a tuple of two integers (for rectangular kernels)
                * a list of integers or tuples (one for each layer)
            num_layers: Number of LSTM layers stacked on each other.
            batch_first: If ``True``, the input tensor is expected as
                (b, t, c, h, w); otherwise as (t, b, c, h, w). The returned
                layer outputs are laid out as (b, t, c, h, w) in both cases.
            bias: If ``True``, adds a learnable bias to the output.
            return_all_layers: If ``True``, will return the list of computations
                for all layers.

        Raises:
            ValueError: If the number of hidden dimensions or kernel sizes does
                not match ``num_layers``.
        """
        super().__init__()

        # Normalize hidden_dim to a list of ints. A caller-supplied list is
        # copied so that mutating it later cannot change this module's config.
        if isinstance(hidden_dim, int):
            self.hidden_dim = [hidden_dim] * num_layers
        else:
            self.hidden_dim = list(hidden_dim)

        # Normalize kernel_size to a list of (height, width) tuples.
        if isinstance(kernel_size, int | tuple):
            ks_list = [kernel_size] * num_layers
        else:
            ks_list = kernel_size
        self.kernel_size = [(ks, ks) if isinstance(ks, int) else ks for ks in ks_list]

        if not len(self.kernel_size) == len(self.hidden_dim) == num_layers:
            raise ValueError('Inconsistent list length.')

        self.input_dim = input_dim
        self.num_layers = num_layers
        self.batch_first = batch_first
        self.bias = bias
        self.return_all_layers = return_all_layers

        # Layer 0 consumes the raw input; every later layer consumes the hidden
        # state sequence emitted by the layer below it.
        cell_list = []
        for i in range(self.num_layers):
            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
            cell_list.append(
                ConvLSTMCell(
                    input_dim=cur_input_dim,
                    hidden_dim=self.hidden_dim[i],
                    kernel_size=self.kernel_size[i],
                    bias=self.bias,
                )
            )
        self.cell_list = nn.ModuleList(cell_list)

    def forward(
        self,
        input_tensor: torch.Tensor,
        hidden_state: list[tuple[torch.Tensor, torch.Tensor]] | None = None,
    ) -> tuple[list[torch.Tensor], list[tuple[torch.Tensor, torch.Tensor]]]:
        """Forward pass of the ConvLSTM.

        Args:
            input_tensor: A 5-D Tensor of shape (b, t, c, h, w) if
                ``batch_first`` is ``True``, else (t, b, c, h, w).
            hidden_state: An optional initial hidden state: one
                ``(hidden, cell)`` tuple per layer. If ``None``, zero states
                are created for every layer.

        Returns:
            A tuple containing layer_output_list and last_state_list.

            * layer_output_list: per-layer hidden state sequences, each stacked
              along dimension 1, i.e. (b, t, hidden_dim, h, w).
            * last_state_list: per-layer ``(hidden, cell)`` states after the
              final time step.

            Both lists contain only the last layer's entry unless
            ``return_all_layers`` is ``True``.
        """
        if not self.batch_first:
            # (t, b, c, h, w) -> (b, t, c, h, w); everything below is batch-first.
            # NOTE: the outputs are not permuted back to time-first.
            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)

        b, seq_len, _, h, w = input_tensor.size()

        if hidden_state is None:
            hidden_state = self._init_hidden(batch_size=b, image_size=(h, w))

        layer_output_list: list[torch.Tensor] = []
        last_state_list: list[tuple[torch.Tensor, torch.Tensor]] = []
        cur_layer_input = input_tensor

        for layer_idx in range(self.num_layers):
            cell = self.cell_list[layer_idx]
            h_state, c_state = hidden_state[layer_idx]

            # Unroll this layer over time, carrying the state forward.
            output_inner = []
            for t in range(seq_len):
                h_state, c_state = cell(
                    input_tensor=cur_layer_input[:, t, :, :, :],
                    cur_state=(h_state, c_state),
                )
                output_inner.append(h_state)

            # The hidden state sequence becomes the next layer's input.
            layer_output = torch.stack(output_inner, dim=1)
            cur_layer_input = layer_output

            layer_output_list.append(layer_output)
            last_state_list.append((h_state, c_state))

        if not self.return_all_layers:
            layer_output_list = layer_output_list[-1:]
            last_state_list = last_state_list[-1:]

        return layer_output_list, last_state_list

    def _init_hidden(
        self, batch_size: int, image_size: tuple[int, int]
    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """Initializes the hidden states for all layers.

        Args:
            batch_size: The size of the batch dimension.
            image_size: A tuple of (height, width) for the spatial dimensions.

        Returns:
            A list of tuples, where each tuple contains the hidden state and cell
            state tensors for a layer. Each tensor has shape
            (batch_size, hidden_dim, height, width).
        """
        # nn.ModuleList yields plain nn.Module; cast for the type checker only.
        return [
            cast(ConvLSTMCell, cell).init_hidden(batch_size, image_size)
            for cell in self.cell_list
        ]