Source code for torchgeo.datasets.presto

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

"""Presto Embeddings dataset."""

from datetime import datetime

import einops
import torch
from matplotlib import pyplot as plt
from matplotlib.figure import Figure

from .geo import RasterDataset
from .utils import Sample


class PrestoEmbeddings(RasterDataset):
    """Presto Embeddings dataset.

    `Geospatial embeddings <https://nasaharvest.github.io/presto-embeddings/>`__
    for Togo generated using the `Presto geospatial foundation model
    <https://arxiv.org/abs/2304.14065>`__.

    Presto geospatial embeddings provide a compressed representation of Earth
    Observation data, enabling more efficient mapping and analysis. Embeddings are
    generated by using the Presto encoder to compress location information, optical
    imagery (Sentinel-2), radar imagery (Sentinel-1), climatology data (ERA5), and
    elevation data (SRTM) over the course of a year (March 2019 - March 2020). Each
    embedding contains 128 features representing a single pixel at 10 m resolution.
    Embeddings can be used in place of raw Earth Observation data for various
    machine-learning tasks, such as classification, clustering, and anomaly
    detection.

    The dataset can be downloaded from one of two sources:

    * `Hugging Face <https://huggingface.co/datasets/izvonkov/Togo_Presto_Embeddings>`__
    * `Google Earth Engine <https://code.earthengine.google.com/?asset=users/izvonkov/Togo/Presto_embeddings_v2025_06_19>`__

    If you use this dataset in your research, please cite the following paper:

    * https://arxiv.org/abs/2511.02923

    .. versionadded:: 0.9
    """

    filename_glob = 'Togo_Presto_embeddings_*'

    # Timestamp in filename is when embeddings were released, not the dates they cover
    # TODO: source code shows ±31 days being added to this range?
    mint = datetime(2019, 3, 1)
    maxt = datetime(2020, 3, 1)

    # One band per embedding feature, named '0' .. '127'
    all_bands = tuple(map(str, range(128)))

    def plot(
        self, sample: Sample, show_titles: bool = True, suptitle: str | None = None
    ) -> Figure:
        """Plot a sample from the dataset.

        .. warning::

           Visualizations are generated using PCA on each image *individually*,
           and are thus not comparable across images. The plot method is provided
           for visualization purposes only and should not be used to draw
           conclusions.

        Pixels whose embedding is entirely zero, or contains non-finite values,
        are treated as nodata and rendered white. If fewer than three principal
        components can be computed (very few valid pixels or a small band
        subset), the missing color channels are left at zero.

        Args:
            sample: a sample returned by :meth:`RasterDataset.__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        _, h, w = sample['image'].shape

        # One row per pixel, one column per embedding feature. PCA requires a
        # floating-point input.
        pixels = einops.rearrange(sample['image'], 'c h w -> (h w) c').float()

        # A pixel is nodata if every feature is zero. Testing each feature (rather
        # than the row sum) keeps valid pixels whose features happen to cancel
        # out. Non-finite rows are excluded too, since they would poison the PCA.
        valid = (pixels != 0).any(dim=1) & torch.isfinite(pixels).all(dim=1)

        # Nodata pixels stay white (1.0 in every channel).
        rgb = torch.ones(h * w, 3, dtype=pixels.dtype, device=pixels.device)

        num_valid = int(valid.sum())
        if num_valid > 0:
            points = pixels[valid]

            # torch.pca_lowrank requires q <= min(rows, cols), so cap the number
            # of components by the valid-pixel count and the band count.
            q = min(3, points.shape[0], points.shape[1])
            _, _, V = torch.pca_lowrank(points, q=q)
            proj = points @ V

            # Min-max normalize each component to [0, 1] over the valid pixels;
            # clamp the range so a constant component does not divide by zero.
            proj = proj - proj.min(dim=0, keepdim=True).values
            span = proj.max(dim=0, keepdim=True).values
            proj = proj / span.clamp_min(torch.finfo(proj.dtype).eps)

            # Components beyond q (if any) stay at zero.
            colors = torch.zeros(num_valid, 3, dtype=proj.dtype, device=proj.device)
            colors[:, :q] = proj
            rgb[valid] = colors

        image = einops.rearrange(rgb, '(h w) c -> h w c', h=h, w=w)

        fig, ax = plt.subplots()
        ax.imshow(image.cpu().numpy())
        ax.axis('off')

        if show_titles:
            ax.set_title('Embedding')

        if suptitle is not None:
            # Attach to this figure rather than pyplot's "current" figure
            fig.suptitle(suptitle)

        return fig