Presto
- class torchgeo.models.Presto(band_groups=None, encoder_embedding_size=128, channel_embed_ratio=0.25, month_embed_ratio=0.25, encoder_depth=2, mlp_ratio=4, encoder_num_heads=8, decoder_embedding_size=128, decoder_depth=2, decoder_num_heads=8, max_sequence_length=24)
Bases: Module
Pretrained Remote Sensing Transformer (Presto).
Added in version 0.9.
- __init__(band_groups=None, encoder_embedding_size=128, channel_embed_ratio=0.25, month_embed_ratio=0.25, encoder_depth=2, mlp_ratio=4, encoder_num_heads=8, decoder_embedding_size=128, decoder_depth=2, decoder_num_heads=8, max_sequence_length=24)
Initialize a new Presto instance.
- Parameters:
band_groups (dict[str, Sequence[int]] | None) – Mapping of band group names to channel indices.
encoder_embedding_size (int) – Size of the embedding for each token in the encoder.
channel_embed_ratio (float) – Ratio of the embedding size to use for channel embeddings in the encoder.
month_embed_ratio (float) – Ratio of the embedding size to use for month embeddings in the encoder.
encoder_depth (int) – Number of Transformer blocks in the encoder.
mlp_ratio (int) – Ratio of the hidden dimension in the MLP compared to the embedding size in the encoder.
encoder_num_heads (int) – Number of attention heads in each Transformer block in the encoder.
decoder_embedding_size (int) – Size of the embedding for each token in the decoder.
decoder_depth (int) – Number of Transformer blocks in the decoder.
decoder_num_heads (int) – Number of attention heads in each Transformer block in the decoder.
max_sequence_length (int) – Maximum length of the input sequence.
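The ratio parameters above are easier to read as concrete sizes. The following is a minimal sketch, not part of torchgeo: `derive_encoder_sizes` is a hypothetical helper that only does the arithmetic implied by the parameter descriptions. Truncating with `int()` and requiring the embedding size to be divisible by the number of heads are assumptions based on common Transformer practice, not confirmed behavior of this class.

```python
def derive_encoder_sizes(
    embedding_size: int = 128,
    channel_embed_ratio: float = 0.25,
    month_embed_ratio: float = 0.25,
    mlp_ratio: int = 4,
    num_heads: int = 8,
) -> dict[str, int]:
    """Hypothetical helper: turn the documented ratios into integer sizes."""
    if channel_embed_ratio + month_embed_ratio > 1:
        raise ValueError("embedding ratios must not sum to more than 1")
    # Assumption: multi-head attention splits the embedding evenly across heads.
    if embedding_size % num_heads != 0:
        raise ValueError("embedding_size must be divisible by num_heads")
    return {
        # Share of each token embedding used for channel embeddings.
        "channel": int(embedding_size * channel_embed_ratio),
        # Share of each token embedding used for month embeddings.
        "month": int(embedding_size * month_embed_ratio),
        # Hidden width of the MLP relative to the embedding size.
        "mlp_hidden": embedding_size * mlp_ratio,
        # Per-head width (assumed, see above).
        "head_dim": embedding_size // num_heads,
    }


sizes = derive_encoder_sizes()
print(sizes)
# {'channel': 32, 'month': 32, 'mlp_hidden': 512, 'head_dim': 16}
```

With the documented defaults, a quarter of each 128-dimensional encoder token (32 values) is used for channel embeddings, another quarter for month embeddings, and the MLP hidden dimension is 512.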
- forward(x, dynamic_world, latlons, mask=None, month=0)
Forward pass of the Presto model.
- Parameters:
x (Tensor) – Input tensor of shape [batch, timesteps, channels].
dynamic_world (Tensor) – Dynamic world tensor of shape [batch, timesteps].
latlons (Tensor) – Latitude and longitude tensor of shape [batch, 2].
mask (Tensor | None) – Mask tensor of shape [batch, timesteps, channels]. Defaults to None.
month (Tensor | int) – Month tensor or integer representing the month. Defaults to 0.
- Returns:
Tuple containing the reconstructed inputs for each channel group and the dynamic world output.
- Return type:
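The input shapes documented for forward must agree with each other on the batch and timestep dimensions. As a rough sketch, the hypothetical helper below (`check_forward_shapes`, not part of torchgeo) checks that consistency on plain shape tuples before a call. The check against `max_sequence_length` is an assumption that the number of timesteps is what that parameter bounds, and the channel count of 17 in the usage line is purely illustrative.

```python
from typing import Optional, Sequence


def check_forward_shapes(
    x_shape: Sequence[int],
    dynamic_world_shape: Sequence[int],
    latlons_shape: Sequence[int],
    mask_shape: Optional[Sequence[int]] = None,
    max_sequence_length: int = 24,
) -> tuple[int, int, int]:
    """Hypothetical helper: validate shapes against the documented layout."""
    if len(x_shape) != 3:
        raise ValueError("x must have shape [batch, timesteps, channels]")
    batch, timesteps, channels = x_shape
    if tuple(dynamic_world_shape) != (batch, timesteps):
        raise ValueError("dynamic_world must have shape [batch, timesteps]")
    if tuple(latlons_shape) != (batch, 2):
        raise ValueError("latlons must have shape [batch, 2]")
    # The mask is optional; when given, it mirrors the shape of x.
    if mask_shape is not None and tuple(mask_shape) != tuple(x_shape):
        raise ValueError("mask must have the same shape as x")
    # Assumption: max_sequence_length limits the number of timesteps.
    if timesteps > max_sequence_length:
        raise ValueError("timesteps exceeds max_sequence_length")
    return batch, timesteps, channels


# Illustrative shapes: 4 samples, 12 timesteps, 17 channels.
shapes = check_forward_shapes((4, 12, 17), (4, 12), (4, 2), mask_shape=(4, 12, 17))
print(shapes)  # (4, 12, 17)
```

With real tensors, the same check could be fed `tuple(x.shape)`, `tuple(dynamic_world.shape)`, and so on before calling forward.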