Source code for torchgeo.datasets.presto
# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.
"""Presto Embeddings dataset."""
from datetime import datetime
import einops
import torch
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from .geo import RasterDataset
from .utils import Sample
[docs]
class PrestoEmbeddings(RasterDataset):
    """Presto Embeddings dataset.

    `Geospatial embeddings <https://nasaharvest.github.io/presto-embeddings/>`__ for
    Togo generated using the `Presto geospatial foundation model
    <https://arxiv.org/abs/2304.14065>`__.

    Presto geospatial embeddings provide a compressed representation of Earth
    Observation data, enabling more efficient mapping and analysis. Embeddings are
    generated by using the Presto encoder to compress location information, optical
    imagery (Sentinel-2), radar imagery (Sentinel-1), climatology data (ERA5), and
    elevation data (SRTM) over the course of a year (March 2019 - March 2020). Each
    embedding contains 128 features representing a single 10 m resolution pixel
    on Earth. Embeddings can be used in place of raw Earth Observation data for
    various machine-learning tasks, such as classification, clustering, and anomaly
    detection.

    The dataset can be downloaded from one of two sources:

    * `Hugging Face <https://huggingface.co/datasets/izvonkov/Togo_Presto_Embeddings>`__
    * `Google Earth Engine <https://code.earthengine.google.com/?asset=users/izvonkov/Togo/Presto_embeddings_v2025_06_19>`__

    If you use this dataset in your research, please cite the following paper:

    * https://arxiv.org/abs/2511.02923

    .. versionadded:: 0.9
    """

    filename_glob = 'Togo_Presto_embeddings_*'

    # Timestamp in filename is when embeddings were released, not the dates they cover
    # TODO: source code shows ±31 days being added to this range?
    mint = datetime(2019, 3, 1)
    maxt = datetime(2020, 3, 1)

    # One band per embedding feature, named '0' through '127'
    all_bands = tuple(map(str, range(128)))

    def plot(
        self, sample: Sample, show_titles: bool = True, suptitle: str | None = None
    ) -> Figure:
        """Plot a sample from the dataset.

        .. warning::

           Visualizations are generated using PCA on each image *individually*, and
           are thus not comparable across images. The plot method is provided for
           visualization purposes only and should not be used to draw conclusions.

        Pixels whose embedding features are all zero are treated as nodata and
        drawn white.

        Args:
            sample: a sample returned by :meth:`RasterDataset.__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        # PCA and imshow need a detached, floating-point CPU tensor
        data = sample['image'].detach().cpu().float()
        _, h, w = data.shape
        # One row per pixel, one column per embedding feature
        A = einops.rearrange(data, 'c h w -> (h w) c')

        # Pixels with an all-zero embedding are treated as nodata
        valid = A.ne(0).any(dim=1)
        n_valid = int(valid.sum())

        # Nodata pixels stay white (1); valid pixels are coloured below
        rgb = torch.ones(h * w, 3)
        if n_valid > 0:
            # Use PCA to project embeddings down to (at most) 3D space.
            # pca_lowrank requires q <= min(rows, cols), so shrink q for tiny inputs.
            q = min(3, n_valid, A.shape[1])
            _, _, V = torch.pca_lowrank(A[valid], q=q)
            B = A[valid] @ V

            # Min-max scale each component to [0, 1]; clamp the range so that a
            # constant component maps to 0 instead of producing NaN from 0 / 0
            lo = B.min(dim=0, keepdim=True).values
            hi = B.max(dim=0, keepdim=True).values
            B = (B - lo) / (hi - lo).clamp_min(torch.finfo(B.dtype).eps)

            # Missing components (q < 3) are left as 0 for valid pixels
            colors = torch.zeros(n_valid, 3)
            colors[:, :q] = B
            rgb[valid] = colors

        image = einops.rearrange(rgb, '(h w) c -> h w c', h=h, w=w)

        fig, ax = plt.subplots()
        ax.imshow(image.numpy())
        ax.axis('off')

        if show_titles:
            ax.set_title('Embedding')

        # Attach the suptitle to the returned figure, not pyplot's "current" one
        if suptitle is not None:
            fig.suptitle(suptitle)

        return fig