# Source code for torchgeo.datasets.esd

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

"""Embedded Seamless Data."""

from collections.abc import Sequence

import torch
from einops import rearrange
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from torch import Tensor

from .geo import RasterDataset
from .utils import GeoSlice, Sample


class ESDQuantizer:
    """Decode ESD quantization indices into continuous embedding codes.

    ESD stores each embedding as a single integer index into a factorized
    codebook. :class:`ESDQuantizer` splits every index into one digit per
    embedding dimension, using a mixed-radix basis built from ``levels``, and
    rescales each digit to a continuous value in approximately [-1, 1]. With
    the default levels ``(8, 8, 8, 5, 5, 5)`` there are 64000 distinct indices,
    each decoded into a 6-dimensional code.

    Key points:

    * Factorized decoding: each index is split into ``len(levels)`` digits,
      where digit ``i`` lies in ``[0, levels[i])`` and digit 0 is the least
      significant.
    * Continuous mapping: digit ``d`` of a dimension with ``L`` levels maps to
      ``(d - L // 2) / (L // 2)``. For odd ``L`` this spans exactly [-1, 1];
      for even ``L`` the maximum is ``(L - 1 - L // 2) / (L // 2)`` (e.g. 0.75
      for ``L = 8``).
    * Fully vectorized: decoding operates on entire tensors at once and runs
      on whatever device the input tensor lives on.
    * Flexible input: supports arbitrary leading dimensions (..., H, W).

    Usage:

    .. code-block:: python

        quantizer = ESDQuantizer()
        decoded = quantizer.quantize(torch.from_numpy(ESD_codes.astype(np.int32)))

    .. note::
        :meth:`quantize` moves the embedding dimension to position -3, so an
        input of shape (..., H, W) yields an output of shape (..., D, H, W),
        where ``D = len(levels)``.

    """

    def __init__(self, levels: Sequence[int] = (8, 8, 8, 5, 5, 5)) -> None:
        """Initialize the quantization levels for the embedding.

        Args:
            levels: Number of quantization levels for each embedding dimension,
                ordered from the least to the most significant digit of an index.

        Raises:
            ValueError: If *levels* is empty or any entry is smaller than 2 (a
                single level would make the rescaling divide by zero).
        """
        levels = tuple(int(level) for level in levels)
        if not levels:
            raise ValueError('levels must contain at least one entry')
        if any(level < 2 for level in levels):
            raise ValueError(f'every quantization level must be >= 2, got {levels}')

        # Number of levels per embedding dimension, shape (D,).
        self._levels = torch.tensor(levels, dtype=torch.int32)

        # Mixed-radix place values: basis[i] = prod(levels[:i]), shape (D,).
        self._basis = torch.cumprod(
            torch.tensor([1, *levels[:-1]], dtype=torch.int32), dim=0
        )

    def indices_to_codes(self, indices: Tensor) -> Tensor:
        """Convert embedding indices to normalized continuous codes.

        Args:
            indices: Integer-valued tensor of quantized embedding indices with
                arbitrary shape (...,).

        Returns:
            Tensor of shape (..., D), where ``D = len(levels)``, holding one
            code per embedding dimension in approximately [-1, 1] (see the
            class docstring for the exact range of even levels).
        """
        # Keep the lookup tables on the input's device so decoding also works
        # for non-CPU tensors; this is a no-op when the devices already match.
        basis = self._basis.to(indices.device)
        levels = self._levels.to(indices.device)

        # Append a trailing axis so the (D,) tables broadcast against every element.
        indices = rearrange(indices, '... -> ... 1')
        # Extract the mixed-radix digit belonging to each embedding dimension.
        level_indices = (indices // basis) % levels

        # Center and scale each digit: d -> (d - L // 2) / (L // 2).
        half = levels // 2
        return (level_indices - half) / half

    def quantize(self, input: Tensor) -> Tensor:
        """Decode a tensor of ESD indices into channel-first embedding codes.

        Despite its name, this method performs decoding: it is
        :meth:`indices_to_codes` followed by moving the embedding axis.

        Args:
            input: Integer-valued tensor of embedding indices with arbitrary
                shape (..., H, W).

        Returns:
            Tensor of shape (..., D, H, W) with codes in approximately [-1, 1],
            where ``D = len(levels)``. The embedding axis is placed at position
            -3 for compatibility with the expected output layout.
        """
        return self.indices_to_codes(input).movedim(-1, -3)


class EmbeddedSeamlessData(RasterDataset):
    """Embedded Seamless Data (ESD).

    The `Embedded Seamless Data (ESD) <https://arxiv.org/abs/2601.11183>`__ is a
    global, analysis-ready Earth embedding dataset at 30-meter resolution, designed
    to overcome the computational and storage challenges of planetary-scale Earth
    system science. By transforming multi-sensor satellite observations into
    compact, quantized latent vectors, ESD reduces the original data volume
    (~1 PB for a full year of global land surfaces) to approximately 2.4 TB,
    enabling decadal-scale analysis on standard workstations.

    Key features:

    * **Longitudinal Consistency**: Provides a continuous record from 2000 to 2024,
      harmonized from Landsat 5, 7, 8, 9, MODIS Terra and NASADEM imagery.
    * **High Reconstructive Fidelity**: Achieves a Mean Absolute Error (MAE) of
      0.013 across six spectral bands, ensuring the embeddings retain physically
      meaningful surface information.
    * **Semantic Intelligence**: Captures complex land surface patterns,
      outperforming raw sensor fusion data for land-cover classification
      (global accuracy 79.74%).
    * **Implicit Denoising**: Filters transient noise such as clouds and shadows
      via the ESDNet architecture, producing clean signals suitable for temporal
      and environmental monitoring.
    * **Few-Shot Proficiency**: Supports robust learning with minimal labeled
      data, ideal for regions with scarce ground-truth measurements.
    * **Compact and Vectorized**: Each 30-meter pixel is represented by a
      high-dimensional embedding vector, which can be aggregated, compared, or
      analyzed efficiently without reconstructing raw imagery.

    The dataset covers terrestrial land surfaces, shallow waters, intertidal and
    reef zones, inland waterways, and coastal regions. Polar coverage is limited
    by satellite orbits and sensor availability.

    Produced by the ESDNet framework, ESD provides an ultra-lightweight, globally
    consistent representation of surface conditions, enabling flexible,
    high-resolution analysis of land surface dynamics over decades.

    If you use this dataset in your research, please refer to:

    * Paper: https://arxiv.org/abs/2601.11183
    * Code: https://github.com/shuangchencc/ESD
    * Dataset: https://data-starcloud.pcl.ac.cn/iearthdata/64

    .. versionadded:: 0.9
    """

    # Example filename: SDC30_EBD_V001_02VMN_2024.tif
    filename_glob = 'SDC30_EBD_*'
    # The greedy ``.*_`` anchors on the last underscore, capturing the 4-digit year.
    filename_regex = r'.*_(?P<date>\d{4})'
    date_format = '%Y'

    # Stateless decoder shared by all instances; turns raw indices into codes.
    quantizer = ESDQuantizer()

    def __getitem__(self, index: GeoSlice) -> Sample:
        """Retrieve input, target, and/or metadata indexed by spatiotemporal slice.

        The raw integer codes read from disk are decoded with :attr:`quantizer`,
        which inserts an embedding axis at position -3 of ``sample['image']``.

        Args:
            index: [xmin:xmax:xres, ymin:ymax:yres, tmin:tmax:tres] coordinates
                to index.

        Returns:
            Sample of input, target, and/or metadata at that index.

        Raises:
            IndexError: If *index* is not found in the dataset.
        """
        sample = super().__getitem__(index)
        # NOTE(review): decoding assumes the pixels are exact integer indices. If
        # the parent class resamples with a non-nearest method when the query
        # resolution differs from the native 30 m grid, interpolated values would
        # decode to meaningless codes -- confirm the resampling mode.
        sample['image'] = self.quantizer.quantize(sample['image'])
        return sample

    def plot(
        self, sample: Sample, show_titles: bool = True, suptitle: str | None = None
    ) -> Figure:
        """Plot a sample from the dataset.

        Embedding dimensions 5, 1 and 2 are averaged over the leading (time) axis
        and shown as red, green and blue, respectively. Pixels whose codes are all
        (close to) zero over the first 12 time steps are rendered transparent.

        Args:
            sample: Sample dict containing an 'image' tensor of decoded codes with
                shape (months, embedding_dims, H, W), as returned by
                :meth:`__getitem__`.
            show_titles: Whether to show titles on subplots.
            suptitle: Optional figure suptitle.

        Returns:
            a matplotlib Figure with the rendered sample
        """
        vectors = sample['image']
        _months, _channels, H, W = vectors.shape

        # Valid mask: a pixel is valid if any code in any of the first 12 months
        # is non-zero. NOTE(review): a raw index of 0 decodes to all -1 codes, not
        # zeros, so confirm which decoded value actually marks nodata.
        nonzero = ~torch.isclose(
            vectors, torch.tensor(0.0, device=vectors.device), atol=1e-6
        )
        valid_mask = nonzero[:12].any(dim=1).any(dim=0)  # shape (H, W)

        # Map three embedding dimensions to RGB: average each one over time, then
        # rescale from [-1, 1] to [0, 1].
        R = (vectors[:, 5].mean(dim=0) + 1) / 2
        G = (vectors[:, 1].mean(dim=0) + 1) / 2
        B = (vectors[:, 2].mean(dim=0) + 1) / 2

        # Assemble an RGBA uint8 image on the CPU for matplotlib.
        disp_img = torch.zeros(H, W, 4, dtype=torch.uint8, device='cpu')
        disp_img[..., 0] = (R.clamp(0, 1) * 255).to(torch.uint8).cpu()
        disp_img[..., 1] = (G.clamp(0, 1) * 255).to(torch.uint8).cpu()
        disp_img[..., 2] = (B.clamp(0, 1) * 255).to(torch.uint8).cpu()
        disp_img[..., 3] = valid_mask.to(torch.uint8).cpu() * 255  # alpha channel

        fig, ax = plt.subplots()
        ax.imshow(disp_img.numpy())
        ax.axis('off')

        if show_titles:
            ax.set_title('ESD Embedding Visualization')

        if suptitle is not None:
            # Attach to this figure explicitly rather than pyplot's "current" one.
            fig.suptitle(suptitle)

        return fig